use std::sync::Arc;
use std::time::{Duration, Instant};
use dimpl::{Config, Dtls};
use crate::common::*;
#[test]
#[cfg(feature = "rcgen")]
fn dtls13_key_update_on_aead_limit() {
    use dimpl::certificate::generate_self_signed_certificate;

    // One client -> server -> client exchange at a fixed instant. Returns how
    // many application-data records the server surfaced during the exchange.
    fn exchange_round(client: &mut Dtls, server: &mut Dtls, at: Instant) -> usize {
        client.handle_timeout(at).expect("client timeout");
        let outbound = drain_outputs(client);
        deliver_packets(&outbound.packets, server);
        server.handle_timeout(at).expect("server timeout");
        let inbound = drain_outputs(server);
        deliver_packets(&inbound.packets, client);
        inbound.app_data.len()
    }

    let _ = env_logger::try_init();
    let client_cert = generate_self_signed_certificate().expect("gen client cert");
    let server_cert = generate_self_signed_certificate().expect("gen server cert");

    // Deliberately tiny AEAD budget: 100 records must cross it many times,
    // which (per the final assertion message) is meant to exercise KeyUpdate.
    let config = Arc::new(
        Config::builder()
            .aead_encryption_limit(10)
            .build()
            .expect("build config"),
    );

    let mut now = Instant::now();
    let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now);
    client.set_active(true);
    let mut server = Dtls::new_13(config, server_cert, now);
    server.set_active(false);

    // Handshake: at most 30 ticks, advancing the clock 50 ms between ticks
    // until both sides have reported a connection.
    let (mut client_up, mut server_up) = (false, false);
    for _tick in 0..30 {
        client.handle_timeout(now).expect("client timeout");
        server.handle_timeout(now).expect("server timeout");
        let from_client = drain_outputs(&mut client);
        let from_server = drain_outputs(&mut server);
        client_up = client_up || from_client.connected;
        server_up = server_up || from_server.connected;
        deliver_packets(&from_client.packets, &mut server);
        deliver_packets(&from_server.packets, &mut client);
        if client_up && server_up {
            break;
        }
        now += Duration::from_millis(50);
    }
    assert!(client_up, "Client should connect");
    assert!(server_up, "Server should connect");

    // Send 100 messages, pumping three exchange rounds after each one, and
    // total up everything the server handed to the application.
    let server_received: usize = (0..100)
        .map(|i| {
            client
                .send_application_data(format!("Message {}", i).as_bytes())
                .expect("send app data");
            now += Duration::from_millis(10);
            (0..3)
                .map(|_| exchange_round(&mut client, &mut server, now))
                .sum::<usize>()
        })
        .sum();

    assert_eq!(
        server_received, 100,
        "All messages should be received (proves KeyUpdate worked transparently)"
    );
}
#[test]
#[cfg(feature = "rcgen")]
fn dtls13_key_update_bidirectional_after_limit() {
    use dimpl::certificate::generate_self_signed_certificate;

    // Advances `first`, hands its datagrams to `second`, advances `second`, and
    // hands those datagrams back to `first`. `panic_msgs` are the `expect`
    // texts for the two `handle_timeout` calls, in call order. Returns the
    // number of application-data records `second` produced in this exchange.
    fn relay(first: &mut Dtls, second: &mut Dtls, at: Instant, panic_msgs: [&str; 2]) -> usize {
        first.handle_timeout(at).expect(panic_msgs[0]);
        let forward = drain_outputs(first);
        deliver_packets(&forward.packets, second);
        second.handle_timeout(at).expect(panic_msgs[1]);
        let backward = drain_outputs(second);
        deliver_packets(&backward.packets, first);
        backward.app_data.len()
    }

    let _ = env_logger::try_init();
    let client_cert = generate_self_signed_certificate().expect("gen client cert");
    let server_cert = generate_self_signed_certificate().expect("gen server cert");

    // Small AEAD budget so that 100 records in each direction exceed it.
    let config = Arc::new(
        Config::builder()
            .aead_encryption_limit(10)
            .build()
            .expect("build config"),
    );

    let mut now = Instant::now();
    let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now);
    client.set_active(true);
    let mut server = Dtls::new_13(config, server_cert, now);
    server.set_active(false);

    // Handshake: up to 30 ticks of 50 ms; stop as soon as both sides connect.
    let mut client_ready = false;
    let mut server_ready = false;
    for _ in 0..30 {
        client.handle_timeout(now).expect("client timeout");
        server.handle_timeout(now).expect("server timeout");
        let client_flight = drain_outputs(&mut client);
        let server_flight = drain_outputs(&mut server);
        client_ready |= client_flight.connected;
        server_ready |= server_flight.connected;
        deliver_packets(&client_flight.packets, &mut server);
        deliver_packets(&server_flight.packets, &mut client);
        if client_ready && server_ready {
            break;
        }
        now += Duration::from_millis(50);
    }
    assert!(client_ready, "Client should connect");
    assert!(server_ready, "Server should connect");

    // Phase 1: client -> server traffic.
    let mut server_received = 0;
    for i in 0..100 {
        client
            .send_application_data(format!("Client msg {}", i).as_bytes())
            .expect("client send");
        now += Duration::from_millis(10);
        for _ in 0..3 {
            server_received += relay(
                &mut client,
                &mut server,
                now,
                ["client timeout", "server timeout"],
            );
        }
    }

    // Phase 2: server -> client traffic.
    let mut client_received = 0;
    for i in 0..100 {
        server
            .send_application_data(format!("Server msg {}", i).as_bytes())
            .expect("server send");
        now += Duration::from_millis(10);
        for _ in 0..3 {
            client_received += relay(
                &mut server,
                &mut client,
                now,
                ["server timeout", "client timeout"],
            );
        }
    }

    assert_eq!(
        server_received, 100,
        "Server should receive all messages (proves KeyUpdate worked for client→server)"
    );
    assert_eq!(
        client_received, 100,
        "Client should receive all messages (proves KeyUpdate worked for server→client)"
    );
}
/// Checks that a record sealed under the pre-KeyUpdate keys, but delivered
/// only after the client has rotated keys, is still decrypted by the server.
/// It also checks that the connection keeps working afterwards.
#[test]
#[cfg(feature = "rcgen")]
fn dtls13_key_update_old_epoch_packet_still_decrypted() {
    use dimpl::certificate::generate_self_signed_certificate;
    let _ = env_logger::try_init();
    let client_cert = generate_self_signed_certificate().expect("gen client cert");
    let server_cert = generate_self_signed_certificate().expect("gen server cert");
    // Small AEAD budget. The delayed record plus the 15 regular messages below
    // exceed it, which is presumably what triggers the KeyUpdate under test.
    let config = Arc::new(
        Config::builder()
            .aead_encryption_limit(10)
            .build()
            .expect("build config"),
    );
    let mut now = Instant::now();
    let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now);
    client.set_active(true);
    let mut server = Dtls::new_13(config, server_cert, now);
    server.set_active(false);

    // Handshake: up to 30 ticks of 50 ms; stop as soon as both sides connect.
    let mut client_connected = false;
    let mut server_connected = false;
    for _ in 0..30 {
        client.handle_timeout(now).expect("client timeout");
        server.handle_timeout(now).expect("server timeout");
        let client_out = drain_outputs(&mut client);
        let server_out = drain_outputs(&mut server);
        client_connected |= client_out.connected;
        server_connected |= server_out.connected;
        deliver_packets(&client_out.packets, &mut server);
        deliver_packets(&server_out.packets, &mut client);
        if client_connected && server_connected {
            break;
        }
        now += Duration::from_millis(50);
    }
    assert!(client_connected, "Client should connect");
    assert!(server_connected, "Server should connect");

    // Seal one record under the initial application keys, but withhold it
    // from the server. This simulates a datagram reordered behind a key change.
    client
        .send_application_data(b"delayed-old-epoch")
        .expect("send delayed msg");
    now += Duration::from_millis(10);
    client.handle_timeout(now).expect("client timeout");
    // Nothing else reads this drained output, so move the packets out rather
    // than cloning them.
    let delayed_packets = drain_outputs(&mut client).packets;
    server.handle_timeout(now).expect("server timeout");
    let server_out = drain_outputs(&mut server);
    deliver_packets(&server_out.packets, &mut client);

    // Push the client past its encryption limit with regular traffic.
    let mut server_received = 0;
    for i in 0..15 {
        let msg = format!("Message {}", i);
        client
            .send_application_data(msg.as_bytes())
            .expect("send app data");
        now += Duration::from_millis(10);
        for _ in 0..3 {
            client.handle_timeout(now).expect("client timeout");
            let client_out = drain_outputs(&mut client);
            deliver_packets(&client_out.packets, &mut server);
            server.handle_timeout(now).expect("server timeout");
            let server_out = drain_outputs(&mut server);
            deliver_packets(&server_out.packets, &mut client);
            server_received += server_out.app_data.len();
        }
    }
    assert_eq!(
        server_received, 15,
        "All regular messages should be received"
    );

    // Release the stale datagram, now that the keys have moved on.
    deliver_packets(&delayed_packets, &mut server);
    now += Duration::from_millis(10);
    server.handle_timeout(now).expect("server timeout");
    let server_out = drain_outputs(&mut server);
    deliver_packets(&server_out.packets, &mut client);
    // This is the property the test is named for. Without this check, a
    // silently dropped stale record would still pass.
    assert_eq!(
        server_out.app_data.len(),
        1,
        "Delayed record from the pre-KeyUpdate epoch should still be decrypted"
    );

    // The connection must keep working after the stale record.
    let mut post_received = 0;
    for i in 0..10 {
        let msg = format!("Post msg {}", i);
        client
            .send_application_data(msg.as_bytes())
            .expect("send post msg");
        now += Duration::from_millis(10);
        for _ in 0..3 {
            client.handle_timeout(now).expect("client timeout");
            let client_out = drain_outputs(&mut client);
            deliver_packets(&client_out.packets, &mut server);
            server.handle_timeout(now).expect("server timeout");
            let server_out = drain_outputs(&mut server);
            deliver_packets(&server_out.packets, &mut client);
            post_received += server_out.app_data.len();
        }
    }
    assert_eq!(
        post_received, 10,
        "All post-KeyUpdate messages should be received (stale packet didn't break connection)"
    );
}
#[test]
#[cfg(feature = "rcgen")]
fn dtls13_key_update_multiple_sequential() {
    use dimpl::certificate::generate_self_signed_certificate;

    // Runs one client -> server -> client exchange at `at` and reports how many
    // application-data records the server emitted during it.
    fn round_trip(client: &mut Dtls, server: &mut Dtls, at: Instant) -> usize {
        client.handle_timeout(at).expect("client timeout");
        let to_server = drain_outputs(client);
        deliver_packets(&to_server.packets, server);
        server.handle_timeout(at).expect("server timeout");
        let to_client = drain_outputs(server);
        deliver_packets(&to_client.packets, client);
        to_client.app_data.len()
    }

    let _ = env_logger::try_init();
    let client_cert = generate_self_signed_certificate().expect("gen client cert");
    let server_cert = generate_self_signed_certificate().expect("gen server cert");

    // Only 3 records per key. The final assertion message expects 30
    // messages to span at least three KeyUpdates.
    let config = Arc::new(
        Config::builder()
            .aead_encryption_limit(3)
            .build()
            .expect("build config"),
    );

    let mut now = Instant::now();
    let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now);
    client.set_active(true);
    let mut server = Dtls::new_13(config, server_cert, now);
    server.set_active(false);

    // Handshake: up to 30 ticks of 50 ms; stop once both sides connect.
    let (mut client_ok, mut server_ok) = (false, false);
    for _attempt in 0..30 {
        client.handle_timeout(now).expect("client timeout");
        server.handle_timeout(now).expect("server timeout");
        let c_flight = drain_outputs(&mut client);
        let s_flight = drain_outputs(&mut server);
        client_ok |= c_flight.connected;
        server_ok |= s_flight.connected;
        deliver_packets(&c_flight.packets, &mut server);
        deliver_packets(&s_flight.packets, &mut client);
        if client_ok && server_ok {
            break;
        }
        now += Duration::from_millis(50);
    }
    assert!(client_ok, "Client should connect");
    assert!(server_ok, "Server should connect");

    // Each message gets five round trips so KeyUpdate/ACK traffic can settle.
    let mut server_received = 0;
    for i in 0..30 {
        client
            .send_application_data(format!("Message {}", i).as_bytes())
            .expect("send app data");
        now += Duration::from_millis(10);
        server_received += (0..5)
            .map(|_| round_trip(&mut client, &mut server, now))
            .sum::<usize>();
    }

    assert_eq!(
        server_received, 30,
        "All 30 messages should be received across 3+ KeyUpdates"
    );
}
/// Drops every client datagram from one round near the encryption limit. It
/// then lets the client's timers run, and checks that later traffic is fully
/// delivered.
#[test]
#[cfg(feature = "rcgen")]
fn dtls13_key_update_with_packet_loss() {
    use dimpl::certificate::generate_self_signed_certificate;

    // Index of the send round whose client datagrams are discarded. With an
    // encryption limit of 5, rounds 0..=4 seal five records, so the loss lands
    // on the round right after the limit is reached. Presumably that is where
    // the KeyUpdate is in flight.
    const DROPPED_ROUND: usize = 5;

    let _ = env_logger::try_init();
    let client_cert = generate_self_signed_certificate().expect("gen client cert");
    let server_cert = generate_self_signed_certificate().expect("gen server cert");
    let config = Arc::new(
        Config::builder()
            .aead_encryption_limit(5)
            .build()
            .expect("build config"),
    );
    let mut now = Instant::now();
    let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now);
    client.set_active(true);
    let mut server = Dtls::new_13(config, server_cert, now);
    server.set_active(false);

    // Handshake: up to 30 ticks of 50 ms; stop as soon as both sides connect.
    let mut client_connected = false;
    let mut server_connected = false;
    for _ in 0..30 {
        client.handle_timeout(now).expect("client timeout");
        server.handle_timeout(now).expect("server timeout");
        let client_out = drain_outputs(&mut client);
        let server_out = drain_outputs(&mut server);
        client_connected |= client_out.connected;
        server_connected |= server_out.connected;
        deliver_packets(&client_out.packets, &mut server);
        deliver_packets(&server_out.packets, &mut client);
        if client_connected && server_connected {
            break;
        }
        now += Duration::from_millis(50);
    }
    assert!(client_connected, "Client should connect");
    assert!(server_connected, "Server should connect");

    // Lossy phase: one pump per message. `i` equals DROPPED_ROUND exactly
    // once, so exactly one round of client output is lost. No extra
    // "already dropped" flag is needed.
    for i in 0..10 {
        let msg = format!("Message {}", i);
        client
            .send_application_data(msg.as_bytes())
            .expect("send app data");
        now += Duration::from_millis(10);
        client.handle_timeout(now).expect("client timeout");
        let client_out = drain_outputs(&mut client);
        if i != DROPPED_ROUND {
            deliver_packets(&client_out.packets, &mut server);
        }
        server.handle_timeout(now).expect("server timeout");
        let server_out = drain_outputs(&mut server);
        deliver_packets(&server_out.packets, &mut client);
    }

    // Recovery phase: let the client's timers fire. `trigger_timeout` takes
    // `now` mutably, so it may advance the clock. Outputs are still drained
    // and exchanged so that any retransmissions reach the peer.
    for _ in 0..10 {
        trigger_timeout(&mut client, &mut now);
        let client_out = drain_outputs(&mut client);
        deliver_packets(&client_out.packets, &mut server);
        server.handle_timeout(now).expect("server timeout");
        let server_out = drain_outputs(&mut server);
        deliver_packets(&server_out.packets, &mut client);
    }

    // After recovery, every new message must get through.
    let mut post_recovery_received = 0;
    for i in 0..10 {
        let msg = format!("Post-recovery {}", i);
        client
            .send_application_data(msg.as_bytes())
            .expect("send post-recovery");
        now += Duration::from_millis(10);
        for _ in 0..3 {
            client.handle_timeout(now).expect("client timeout");
            let client_out = drain_outputs(&mut client);
            deliver_packets(&client_out.packets, &mut server);
            server.handle_timeout(now).expect("server timeout");
            let server_out = drain_outputs(&mut server);
            deliver_packets(&server_out.packets, &mut client);
            post_recovery_received += server_out.app_data.len();
        }
    }
    assert_eq!(
        post_recovery_received, 10,
        "All post-recovery messages should be received"
    );
}
#[test]
#[cfg(feature = "rcgen")]
fn dtls13_key_update_high_frequency() {
    use dimpl::certificate::generate_self_signed_certificate;

    // Drives the handshake for up to 30 ticks. The clock advances 50 ms after
    // each tick in which the two sides have not both reported a connection.
    // Returns the (client, server) connected flags.
    fn handshake(client: &mut Dtls, server: &mut Dtls, now: &mut Instant) -> (bool, bool) {
        let mut client_done = false;
        let mut server_done = false;
        for _ in 0..30 {
            client.handle_timeout(*now).expect("client timeout");
            server.handle_timeout(*now).expect("server timeout");
            let client_flight = drain_outputs(client);
            let server_flight = drain_outputs(server);
            client_done |= client_flight.connected;
            server_done |= server_flight.connected;
            deliver_packets(&client_flight.packets, server);
            deliver_packets(&server_flight.packets, client);
            if client_done && server_done {
                break;
            }
            *now += Duration::from_millis(50);
        }
        (client_done, server_done)
    }

    // One client -> server -> client exchange at `at`. Returns the number of
    // application-data records the server surfaced.
    fn round_trip(client: &mut Dtls, server: &mut Dtls, at: Instant) -> usize {
        client.handle_timeout(at).expect("client timeout");
        let upstream = drain_outputs(client);
        deliver_packets(&upstream.packets, server);
        server.handle_timeout(at).expect("server timeout");
        let downstream = drain_outputs(server);
        deliver_packets(&downstream.packets, client);
        downstream.app_data.len()
    }

    let _ = env_logger::try_init();
    let client_cert = generate_self_signed_certificate().expect("gen client cert");
    let server_cert = generate_self_signed_certificate().expect("gen server cert");

    // Two records per key: the most aggressive rotation pressure in this file.
    let config = Arc::new(
        Config::builder()
            .aead_encryption_limit(2)
            .build()
            .expect("build config"),
    );

    let mut now = Instant::now();
    let mut client = Dtls::new_13(Arc::clone(&config), client_cert, now);
    client.set_active(true);
    let mut server = Dtls::new_13(config, server_cert, now);
    server.set_active(false);

    let (client_connected, server_connected) = handshake(&mut client, &mut server, &mut now);
    assert!(client_connected, "Client should connect");
    assert!(server_connected, "Server should connect");

    let mut server_received = 0;
    for i in 0..50 {
        client
            .send_application_data(format!("High-freq msg {}", i).as_bytes())
            .expect("send app data");
        now += Duration::from_millis(10);
        for _ in 0..5 {
            server_received += round_trip(&mut client, &mut server, now);
        }
    }

    assert_eq!(
        server_received, 50,
        "All 50 messages should be received despite high-frequency KeyUpdates"
    );
}