use super::*;
use rand::Rng;
use secp256k1::Parity;
/// Returns a fresh secp256k1 keypair whose secret key is drawn from `rand::rng()`.
///
/// `SecretKey::from_slice` rejects the all-zero scalar and any value that is not
/// below the curve order, so 32 random bytes are *almost* always — but not
/// provably — a valid secret key. Instead of asserting an invariant that does not
/// strictly hold, resample until the bytes are accepted. In practice the loop body
/// runs exactly once.
fn generate_keypair() -> secp256k1::Keypair {
    let secp = secp256k1::Secp256k1::new();
    let mut rng = rand::rng();
    let secret_key = loop {
        let mut secret_bytes = [0u8; 32];
        rng.fill_bytes(&mut secret_bytes);
        if let Ok(sk) = secp256k1::SecretKey::from_slice(&secret_bytes) {
            break sk;
        }
    };
    secp256k1::Keypair::from_secret_key(&secp, &secret_key)
}
/// Produces an 8-byte epoch value filled with random bytes from `rand::rng()`.
fn generate_epoch() -> [u8; 8] {
    let mut bytes = [0u8; 8];
    let mut rng = rand::rng();
    rng.fill_bytes(&mut bytes);
    bytes
}
/// End-to-end IK handshake.
///
/// Both sides exchange messages 1 and 2. They must agree on the handshake hash and
/// learn each other's epochs. Afterwards one transport message is exchanged in each
/// direction.
#[test]
fn test_full_handshake() {
    // Long-term identities and per-side epochs.
    let (init_kp, resp_kp) = (generate_keypair(), generate_keypair());
    let (init_epoch, resp_epoch) = (generate_epoch(), generate_epoch());

    // The IK initiator is configured with the responder's static key up front.
    let mut initiator = HandshakeState::new_initiator(init_kp, resp_kp.public_key());
    initiator.set_local_epoch(init_epoch);
    let mut responder = HandshakeState::new_responder(resp_kp);
    responder.set_local_epoch(resp_epoch);

    assert_eq!(initiator.role(), HandshakeRole::Initiator);
    assert_eq!(responder.role(), HandshakeRole::Responder);
    // Before message 1 the responder has no idea who is calling.
    assert!(responder.remote_static().is_none());

    // Message 1 (initiator -> responder) reveals the initiator's identity and epoch.
    let first = initiator.write_message_1().unwrap();
    assert_eq!(first.len(), HANDSHAKE_MSG1_SIZE);
    responder.read_message_1(&first).unwrap();
    assert!(responder.remote_static().is_some());
    assert_eq!(responder.remote_static().unwrap(), &init_kp.public_key());
    assert_eq!(responder.remote_epoch(), Some(init_epoch));

    // Message 2 (responder -> initiator) completes the handshake on both sides.
    let second = responder.write_message_2().unwrap();
    assert_eq!(second.len(), HANDSHAKE_MSG2_SIZE);
    initiator.read_message_2(&second).unwrap();
    assert!(initiator.is_complete());
    assert!(responder.is_complete());
    assert_eq!(initiator.remote_epoch(), Some(resp_epoch));
    assert_eq!(initiator.handshake_hash(), responder.handshake_hash());

    // Transport phase: one message each way.
    let mut init_session = initiator.into_session().unwrap();
    let mut resp_session = responder.into_session().unwrap();

    let outbound = b"Hello, secure world!";
    let sealed = init_session.encrypt(outbound).unwrap();
    let opened = resp_session.decrypt(&sealed).unwrap();
    assert_eq!(opened, outbound);

    let inbound = b"Hello back!";
    let sealed_reply = resp_session.encrypt(inbound).unwrap();
    let opened_reply = init_session.decrypt(&sealed_reply).unwrap();
    assert_eq!(opened_reply, inbound);
}
/// After an IK handshake, 100 sequential messages round-trip correctly. Both
/// sides' nonce counters end at 100.
#[test]
fn test_multiple_messages() {
    let initiator_keypair = generate_keypair();
    let responder_keypair = generate_keypair();
    let responder_pub = responder_keypair.public_key();

    let mut initiator = HandshakeState::new_initiator(initiator_keypair, responder_pub);
    initiator.set_local_epoch(generate_epoch());
    let mut responder = HandshakeState::new_responder(responder_keypair);
    responder.set_local_epoch(generate_epoch());

    // Two-message IK exchange.
    let m1 = initiator.write_message_1().unwrap();
    responder.read_message_1(&m1).unwrap();
    let m2 = responder.write_message_2().unwrap();
    initiator.read_message_2(&m2).unwrap();

    let mut tx = initiator.into_session().unwrap();
    let mut rx = responder.into_session().unwrap();

    (0..100).for_each(|i| {
        let text = format!("Message {}", i);
        let sealed = tx.encrypt(text.as_bytes()).unwrap();
        let opened = rx.decrypt(&sealed).unwrap();
        assert_eq!(opened, text.as_bytes());
    });

    // 100 successful round-trips leave both counters at 100.
    assert_eq!(tx.send_nonce(), 100);
    assert_eq!(rx.recv_nonce(), 100);
}
/// An IK initiator must refuse responder-side operations: reading message 1 and
/// writing message 2.
#[test]
fn test_wrong_role_errors() {
    let local = generate_keypair();
    let peer = generate_keypair();

    let mut initiator = HandshakeState::new_initiator(local, peer.public_key());
    initiator.set_local_epoch(generate_epoch());

    let bogus_msg1 = [0u8; HANDSHAKE_MSG1_SIZE];
    assert!(matches!(initiator.read_message_1(&bogus_msg1), Err(_)));
    assert!(matches!(initiator.write_message_2(), Err(_)));
}
/// The responder rejects an all-zero message 1 of the correct length.
///
/// Per the test name, the failure is expected to come from the leading public key
/// not parsing — NOTE(review): the test only asserts `is_err()`, not which check
/// fired.
#[test]
fn test_invalid_pubkey_in_msg1() {
    let mut responder = HandshakeState::new_responder(generate_keypair());
    responder.set_local_epoch(generate_epoch());

    let zeroed = [0u8; HANDSHAKE_MSG1_SIZE];
    let outcome = responder.read_message_1(&zeroed);
    assert!(outcome.is_err());
}
/// Ciphertext from one session must not decrypt under an unrelated session's keys.
///
/// - Session A: keypair1 (initiator) <-> keypair2 (responder).
/// - Session B: keypair1 (initiator) <-> keypair3 (responder).
///
/// A message sealed by A's initiator is fed to B's responder, which must reject it.
#[test]
fn test_decryption_failure_wrong_key() {
    let keypair1 = generate_keypair();
    let keypair2 = generate_keypair();
    let keypair3 = generate_keypair();

    // Session A.
    let mut a_init = HandshakeState::new_initiator(keypair1, keypair2.public_key());
    a_init.set_local_epoch(generate_epoch());
    let mut a_resp = HandshakeState::new_responder(keypair2);
    a_resp.set_local_epoch(generate_epoch());
    let a_msg1 = a_init.write_message_1().unwrap();
    a_resp.read_message_1(&a_msg1).unwrap();
    let a_msg2 = a_resp.write_message_2().unwrap();
    a_init.read_message_2(&a_msg2).unwrap();
    let mut session_a = a_init.into_session().unwrap();

    // Session B: same initiator identity, different responder.
    let mut b_init = HandshakeState::new_initiator(keypair1, keypair3.public_key());
    b_init.set_local_epoch(generate_epoch());
    let mut b_resp = HandshakeState::new_responder(keypair3);
    b_resp.set_local_epoch(generate_epoch());
    let b_msg1 = b_init.write_message_1().unwrap();
    b_resp.read_message_1(&b_msg1).unwrap();
    let b_msg2 = b_resp.write_message_2().unwrap();
    b_init.read_message_2(&b_msg2).unwrap();
    let mut session_b = b_resp.into_session().unwrap();

    // Cross-session decryption must fail authentication.
    let foreign = session_a.encrypt(b"test").unwrap();
    assert!(session_b.decrypt(&foreign).is_err());
}
/// A fresh `CipherState` starts at nonce 0. Each `encrypt` call advances the nonce
/// by exactly one.
#[test]
fn test_cipher_state_nonce_sequence() {
    let mut cipher = CipherState::new([0u8; 32]);
    assert_eq!(cipher.nonce(), 0);
    for expected in 1..=2 {
        let _ = cipher.encrypt(b"test").unwrap();
        assert_eq!(cipher.nonce(), expected);
    }
}
/// After an IK handshake, each resulting session reports the peer's static public
/// key.
#[test]
fn test_session_remote_static() {
    let (alice, bob) = (generate_keypair(), generate_keypair());

    let mut init = HandshakeState::new_initiator(alice, bob.public_key());
    init.set_local_epoch(generate_epoch());
    let mut resp = HandshakeState::new_responder(bob);
    resp.set_local_epoch(generate_epoch());

    let hello = init.write_message_1().unwrap();
    resp.read_message_1(&hello).unwrap();
    let reply = resp.write_message_2().unwrap();
    init.read_message_2(&reply).unwrap();

    let alice_view = init.into_session().unwrap();
    let bob_view = resp.into_session().unwrap();
    assert_eq!(alice_view.remote_static(), &bob.public_key());
    assert_eq!(bob_view.remote_static(), &alice.public_key());
}
/// Pins the IK wire-format sizes.
///
/// NOTE(review): the component names are inferred from the arithmetic (33-byte
/// compressed secp256k1 point, 16-byte AEAD tag, 24-byte sealed epoch = 8 + 16) —
/// confirm against the message layout.
#[test]
fn test_message_sizes() {
    const PUBKEY: usize = 33;
    const TAG: usize = 16;
    const SEALED_EPOCH: usize = 24;

    assert_eq!(EPOCH_SIZE, 8);
    assert_eq!(EPOCH_ENCRYPTED_SIZE, 8 + 16);
    assert_eq!(HANDSHAKE_MSG1_SIZE, PUBKEY + PUBKEY + TAG + SEALED_EPOCH);
    assert_eq!(HANDSHAKE_MSG2_SIZE, PUBKEY + SEALED_EPOCH);
}
/// In IK, the responder starts with no knowledge of the peer. It learns the
/// initiator's static key from message 1 alone.
#[test]
fn test_responder_identity_discovery() {
    let initiator_keypair = generate_keypair();
    let responder_keypair = generate_keypair();

    let mut responder = HandshakeState::new_responder(responder_keypair);
    responder.set_local_epoch(generate_epoch());
    // Nothing is known about the peer before any message arrives.
    assert!(responder.remote_static().is_none());

    let mut initiator =
        HandshakeState::new_initiator(initiator_keypair, responder_keypair.public_key());
    initiator.set_local_epoch(generate_epoch());

    let opening = initiator.write_message_1().unwrap();
    responder.read_message_1(&opening).unwrap();

    let expected = initiator_keypair.public_key();
    assert_eq!(responder.remote_static().unwrap(), &expected);
}
/// Exercises the basic `ReplayWindow` acceptance rules:
///
/// - fresh counters pass `check`, and accepted ones are rejected afterwards;
/// - `highest` tracks the largest accepted counter;
/// - after a jump forward, an older-but-unseen counter inside the window is still
///   accepted exactly once;
/// - accepting that out-of-order counter does not move the high-water mark
///   backwards.
#[test]
fn test_replay_window_basic() {
    let mut window = ReplayWindow::new();
    assert!(window.check(0));
    window.accept(0);
    assert_eq!(window.highest(), 0);
    assert!(!window.check(0));

    assert!(window.check(1));
    window.accept(1);
    assert_eq!(window.highest(), 1);

    // Jump ahead, leaving a gap (2..=9) of unseen counters.
    window.accept(10);
    assert_eq!(window.highest(), 10);
    assert!(!window.check(10));
    // A counter accepted before the jump must stay rejected.
    assert!(!window.check(1));

    // An unseen counter inside the gap is still acceptable, exactly once.
    assert!(window.check(5));
    window.accept(5);
    assert!(!window.check(5));
    // Out-of-order acceptance must not lower the high-water mark.
    assert_eq!(window.highest(), 10);
}
/// After jumping `REPLAY_WINDOW_SIZE + 100` ahead, counters that fell behind the
/// window are rejected. Unseen counters still inside the window remain acceptable.
#[test]
fn test_replay_window_large_jump() {
    let span = REPLAY_WINDOW_SIZE as u64;
    let mut window = ReplayWindow::new();
    window.accept(0);
    window.accept(span + 100);

    for stale in [0, 50] {
        assert!(!window.check(stale));
    }
    for fresh in [span + 99, span + 50] {
        assert!(window.check(fresh));
    }
}
/// Checks the exact edge of the replay window.
///
/// A counter `REPLAY_WINDOW_SIZE - 1` behind the highest accepted counter is still
/// in range. One that is exactly `REPLAY_WINDOW_SIZE` behind is not.
#[test]
fn test_replay_window_boundary() {
    let size = REPLAY_WINDOW_SIZE as u64;
    let mut window = ReplayWindow::new();

    window.accept(size - 1);
    // 0 is `size - 1` behind the highest counter: still inside the window.
    assert!(window.check(0));
    window.accept(0);

    window.accept(size);
    // 0 is now `size` behind and has slid out of the window.
    assert!(!window.check(0));
    // 1 is `size - 1` behind and was never seen.
    assert!(window.check(1));
}
/// Counters 0..1000 are each accepted in order. Every one of them is then reported
/// as a replay, and the high-water mark ends at 999.
#[test]
fn test_replay_window_sequential() {
    let mut window = ReplayWindow::new();

    for counter in 0..1000u64 {
        assert!(window.check(counter), "Counter {} should be acceptable", counter);
        window.accept(counter);
    }

    (0..1000u64).for_each(|counter| {
        assert!(
            !window.check(counter),
            "Counter {} should be rejected as replay",
            counter
        );
    });

    assert_eq!(window.highest(), 999);
}
/// `reset` clears both the high-water mark and the record of seen counters.
#[test]
fn test_replay_window_reset() {
    let counter: u64 = 100;
    let mut window = ReplayWindow::new();

    window.accept(counter);
    assert_eq!(window.highest(), 100);
    assert!(!window.check(counter));

    window.reset();
    assert_eq!(window.highest(), 0);
    assert!(window.check(counter));
}
/// A ciphertext accepted once via `decrypt_with_replay_check` is rejected with
/// `NoiseError::ReplayDetected` the second time. `check_replay` then reports the
/// counter as already used.
#[test]
fn test_session_replay_protection() {
    let keypair1 = generate_keypair();
    let keypair2 = generate_keypair();

    let mut init = HandshakeState::new_initiator(keypair1, keypair2.public_key());
    let mut resp = HandshakeState::new_responder(keypair2);
    init.set_local_epoch(generate_epoch());
    resp.set_local_epoch(generate_epoch());

    let first = init.write_message_1().unwrap();
    resp.read_message_1(&first).unwrap();
    let second = resp.write_message_2().unwrap();
    init.read_message_2(&second).unwrap();

    let mut sender = init.into_session().unwrap();
    let mut receiver = resp.into_session().unwrap();

    // Capture the sender's counter before encrypting so the receiver can be given it.
    let counter = sender.current_send_counter();
    let ciphertext = sender.encrypt(b"test message").unwrap();

    let opened = receiver
        .decrypt_with_replay_check(&ciphertext, counter)
        .unwrap();
    assert_eq!(opened, b"test message");

    // Delivering the same (ciphertext, counter) pair again must be refused.
    let replay = receiver.decrypt_with_replay_check(&ciphertext, counter);
    assert!(matches!(replay, Err(NoiseError::ReplayDetected(_))));
    assert!(receiver.check_replay(counter).is_err());
}
/// IK handshake where the responder's real key has odd parity. The initiator is
/// given the even-parity lift of that key's x-only form, e.g. when only an x-only
/// key is available. The handshake and a first transport message must still
/// succeed.
#[test]
fn test_handshake_with_odd_parity_responder() {
    let secp = secp256k1::Secp256k1::new();
    // Deterministic fixture keys built from fixed hex secrets.
    let keypair_from_hex = |hex_str: &str| {
        let bytes = hex::decode(hex_str).unwrap();
        let sk = secp256k1::SecretKey::from_slice(&bytes).unwrap();
        secp256k1::Keypair::from_secret_key(&secp, &sk)
    };

    let responder_kp =
        keypair_from_hex("b102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1fb0");
    let (responder_xonly, responder_parity) = responder_kp.public_key().x_only_public_key();
    assert_eq!(
        responder_parity,
        Parity::Odd,
        "Test requires odd-parity responder key"
    );

    let initiator_kp =
        keypair_from_hex("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20");

    // Lift the x-only key back to a full point assuming even parity (deliberately wrong).
    let assumed_even_b = responder_xonly.public_key(Parity::Even);
    assert_ne!(
        assumed_even_b,
        responder_kp.public_key(),
        "Even assumption should differ from actual odd key"
    );

    let mut initiator = HandshakeState::new_initiator(initiator_kp, assumed_even_b);
    initiator.set_local_epoch(generate_epoch());
    let mut responder = HandshakeState::new_responder(responder_kp);
    responder.set_local_epoch(generate_epoch());

    let msg1 = initiator.write_message_1().unwrap();
    responder.read_message_1(&msg1).unwrap();
    let msg2 = responder.write_message_2().unwrap();
    initiator.read_message_2(&msg2).unwrap();
    assert!(initiator.is_complete());
    assert!(responder.is_complete());

    let mut sender = initiator.into_session().unwrap();
    let mut receiver = responder.into_session().unwrap();
    let counter = sender.current_send_counter();
    let ciphertext = sender.encrypt(b"parity test").unwrap();
    let plaintext = receiver
        .decrypt_with_replay_check(&ciphertext, counter)
        .unwrap();
    assert_eq!(plaintext, b"parity test");
}
/// End-to-end three-message XK handshake.
///
/// Unlike IK, the responder learns nothing about the initiator from message 1. The
/// initiator learns the responder's epoch from message 2. The responder learns the
/// initiator's identity and epoch only from message 3. Afterwards both sides agree
/// on the handshake hash and can exchange transport messages in both directions.
#[test]
fn test_xk_full_handshake() {
    let initiator_keypair = generate_keypair();
    let responder_keypair = generate_keypair();
    let initiator_epoch = generate_epoch();
    let responder_epoch = generate_epoch();

    let mut initiator =
        HandshakeState::new_xk_initiator(initiator_keypair, responder_keypair.public_key());
    initiator.set_local_epoch(initiator_epoch);
    let mut responder = HandshakeState::new_xk_responder(responder_keypair);
    responder.set_local_epoch(responder_epoch);

    assert_eq!(initiator.role(), HandshakeRole::Initiator);
    assert_eq!(responder.role(), HandshakeRole::Responder);
    assert!(responder.remote_static().is_none());

    // Message 1: 33 bytes; reveals neither identity nor epoch to the responder.
    let first = initiator.write_xk_message_1().unwrap();
    assert_eq!(first.len(), XK_HANDSHAKE_MSG1_SIZE);
    assert_eq!(first.len(), 33);
    responder.read_xk_message_1(&first).unwrap();
    assert!(responder.remote_static().is_none());
    assert!(responder.remote_epoch().is_none());

    // Message 2: 57 bytes; delivers the responder's epoch. Neither side is done yet.
    let second = responder.write_xk_message_2().unwrap();
    assert_eq!(second.len(), XK_HANDSHAKE_MSG2_SIZE);
    assert_eq!(second.len(), 57);
    initiator.read_xk_message_2(&second).unwrap();
    assert_eq!(initiator.remote_epoch(), Some(responder_epoch));
    assert!(!initiator.is_complete());
    assert!(!responder.is_complete());

    // Message 3: 73 bytes; completes the handshake and reveals the initiator.
    let third = initiator.write_xk_message_3().unwrap();
    assert_eq!(third.len(), XK_HANDSHAKE_MSG3_SIZE);
    assert_eq!(third.len(), 73);
    responder.read_xk_message_3(&third).unwrap();
    assert!(initiator.is_complete());
    assert!(responder.is_complete());
    assert!(responder.remote_static().is_some());
    assert_eq!(
        responder.remote_static().unwrap(),
        &initiator_keypair.public_key()
    );
    assert_eq!(responder.remote_epoch(), Some(initiator_epoch));
    assert_eq!(initiator.handshake_hash(), responder.handshake_hash());

    // Transport phase: one message each way.
    let mut init_session = initiator.into_session().unwrap();
    let mut resp_session = responder.into_session().unwrap();

    let outbound = b"Hello via XK!";
    let sealed = init_session.encrypt(outbound).unwrap();
    let opened = resp_session.decrypt(&sealed).unwrap();
    assert_eq!(opened, outbound);

    let inbound = b"XK reply!";
    let sealed_reply = resp_session.encrypt(inbound).unwrap();
    let opened_reply = init_session.decrypt(&sealed_reply).unwrap();
    assert_eq!(opened_reply, inbound);
}
/// Pins the XK wire-format sizes.
///
/// NOTE(review): the component names are inferred from the arithmetic (33-byte
/// compressed secp256k1 point, 16-byte AEAD tag, 24-byte sealed epoch) — confirm
/// against the message layout.
#[test]
fn test_xk_message_sizes() {
    const PUBKEY: usize = 33;
    const TAG: usize = 16;
    const SEALED_EPOCH: usize = 24;

    assert_eq!(XK_HANDSHAKE_MSG1_SIZE, PUBKEY);
    assert_eq!(XK_HANDSHAKE_MSG2_SIZE, PUBKEY + SEALED_EPOCH);
    assert_eq!(XK_HANDSHAKE_MSG3_SIZE, PUBKEY + TAG + SEALED_EPOCH);
}
/// XK keeps the initiator's identity hidden until message 3.
///
/// The responder must not know the peer's static key after message 1 or message 2.
/// After message 3 it must know the key, and the key must be correct.
#[test]
fn test_xk_identity_timing() {
    let initiator_keypair = generate_keypair();
    let responder_keypair = generate_keypair();
    let initiator_pub = initiator_keypair.public_key();

    let mut initiator =
        HandshakeState::new_xk_initiator(initiator_keypair, responder_keypair.public_key());
    initiator.set_local_epoch(generate_epoch());
    let mut responder = HandshakeState::new_xk_responder(responder_keypair);
    responder.set_local_epoch(generate_epoch());
    assert!(responder.remote_static().is_none());

    let first = initiator.write_xk_message_1().unwrap();
    responder.read_xk_message_1(&first).unwrap();
    assert!(
        responder.remote_static().is_none(),
        "XK: responder should NOT know identity after msg1"
    );

    let second = responder.write_xk_message_2().unwrap();
    initiator.read_xk_message_2(&second).unwrap();
    assert!(
        responder.remote_static().is_none(),
        "XK: responder should NOT know identity after msg2"
    );

    let third = initiator.write_xk_message_3().unwrap();
    responder.read_xk_message_3(&third).unwrap();
    assert!(
        responder.remote_static().is_some(),
        "XK: responder should know identity after msg3"
    );
    assert_eq!(responder.remote_static().unwrap(), &initiator_pub);
}
/// XK operations invoked from the wrong role, or before the handshake has reached
/// the required step, must return an error.
#[test]
fn test_xk_wrong_state_errors() {
    let keypair1 = generate_keypair();
    let keypair2 = generate_keypair();

    // A freshly built initiator: none of these calls are valid here.
    let mut initiator = HandshakeState::new_xk_initiator(keypair1, keypair2.public_key());
    initiator.set_local_epoch(generate_epoch());
    let zero_msg1 = [0u8; XK_HANDSHAKE_MSG1_SIZE];
    assert!(initiator.read_xk_message_1(&zero_msg1).is_err());
    assert!(initiator.write_xk_message_2().is_err());
    assert!(initiator.write_xk_message_3().is_err());

    // A freshly built responder: it cannot write msg1 or read msg3 before msg1.
    let mut responder = HandshakeState::new_xk_responder(keypair2);
    responder.set_local_epoch(generate_epoch());
    assert!(responder.write_xk_message_1().is_err());
    let zero_msg3 = [0u8; XK_HANDSHAKE_MSG3_SIZE];
    assert!(responder.read_xk_message_3(&zero_msg3).is_err());
}
/// IK and XK runs between the same two identities, using the same epochs, must
/// end with different handshake hashes.
///
/// Each handshake is first checked for internal agreement: initiator and responder
/// must hold the same hash. This ensures the final inequality compares two valid
/// transcripts, not a diverged or half-finished state.
///
/// NOTE(review): if ephemeral keys are drawn randomly per handshake, two IK runs
/// would also produce different hashes. The inequality alone therefore cannot prove
/// IK and XK are domain-separated. A deterministic-ephemeral test hook would make
/// it conclusive — TODO confirm whether one exists.
#[test]
fn test_xk_handshake_hash_differs_from_ik() {
    let keypair1 = generate_keypair();
    let keypair2 = generate_keypair();
    let epoch1 = generate_epoch();
    let epoch2 = generate_epoch();

    // IK run.
    let mut ik_init = HandshakeState::new_initiator(keypair1, keypair2.public_key());
    ik_init.set_local_epoch(epoch1);
    let mut ik_resp = HandshakeState::new_responder(keypair2);
    ik_resp.set_local_epoch(epoch2);
    let msg1 = ik_init.write_message_1().unwrap();
    ik_resp.read_message_1(&msg1).unwrap();
    let msg2 = ik_resp.write_message_2().unwrap();
    ik_init.read_message_2(&msg2).unwrap();
    assert_eq!(ik_init.handshake_hash(), ik_resp.handshake_hash());
    let ik_hash = ik_init.handshake_hash();

    // XK run with the same identities and epochs.
    let mut xk_init = HandshakeState::new_xk_initiator(keypair1, keypair2.public_key());
    xk_init.set_local_epoch(epoch1);
    let mut xk_resp = HandshakeState::new_xk_responder(keypair2);
    xk_resp.set_local_epoch(epoch2);
    let msg1 = xk_init.write_xk_message_1().unwrap();
    xk_resp.read_xk_message_1(&msg1).unwrap();
    let msg2 = xk_resp.write_xk_message_2().unwrap();
    xk_init.read_xk_message_2(&msg2).unwrap();
    let msg3 = xk_init.write_xk_message_3().unwrap();
    xk_resp.read_xk_message_3(&msg3).unwrap();
    assert_eq!(xk_init.handshake_hash(), xk_resp.handshake_hash());
    let xk_hash = xk_init.handshake_hash();

    assert_ne!(
        ik_hash, xk_hash,
        "IK and XK should produce different handshake hashes"
    );
}
/// After a full XK handshake, 100 sequential messages round-trip correctly. The
/// sender's and receiver's nonce counters both end at 100.
#[test]
fn test_xk_multiple_messages_after_handshake() {
    let local = generate_keypair();
    let remote = generate_keypair();

    let mut initiator = HandshakeState::new_xk_initiator(local, remote.public_key());
    initiator.set_local_epoch(generate_epoch());
    let mut responder = HandshakeState::new_xk_responder(remote);
    responder.set_local_epoch(generate_epoch());

    // Three-message XK exchange.
    let m1 = initiator.write_xk_message_1().unwrap();
    responder.read_xk_message_1(&m1).unwrap();
    let m2 = responder.write_xk_message_2().unwrap();
    initiator.read_xk_message_2(&m2).unwrap();
    let m3 = initiator.write_xk_message_3().unwrap();
    responder.read_xk_message_3(&m3).unwrap();

    let mut init_session = initiator.into_session().unwrap();
    let mut resp_session = responder.into_session().unwrap();

    for text in (0..100).map(|i| format!("XK message {}", i)) {
        let sealed = init_session.encrypt(text.as_bytes()).unwrap();
        let opened = resp_session.decrypt(&sealed).unwrap();
        assert_eq!(opened, text.as_bytes());
    }

    assert_eq!(init_session.send_nonce(), 100);
    assert_eq!(resp_session.recv_nonce(), 100);
}
/// XK counterpart of the IK odd-parity test.
///
/// The responder's real key has odd parity. The initiator is given the even-parity
/// lift of that key's x-only form. The handshake and a first transport message must
/// still succeed.
///
/// The `assert_ne!` mirrors the guard in the IK variant. It makes explicit, in this
/// test, that the initiator really is configured with a key different from the
/// responder's actual public key.
#[test]
fn test_xk_with_odd_parity_responder() {
    let secp = secp256k1::Secp256k1::new();
    let sk_b = secp256k1::SecretKey::from_slice(
        &hex::decode("b102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1fb0").unwrap(),
    )
    .unwrap();
    let kp_b = secp256k1::Keypair::from_secret_key(&secp, &sk_b);
    let (xonly_b, parity_b) = kp_b.public_key().x_only_public_key();
    assert_eq!(
        parity_b,
        Parity::Odd,
        "Test requires odd-parity responder key"
    );

    let sk_a = secp256k1::SecretKey::from_slice(
        &hex::decode("0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20").unwrap(),
    )
    .unwrap();
    let kp_a = secp256k1::Keypair::from_secret_key(&secp, &sk_a);

    // Deliberately wrong: lift the x-only key assuming even parity.
    let assumed_even_b = xonly_b.public_key(Parity::Even);
    assert_ne!(
        assumed_even_b,
        kp_b.public_key(),
        "Even assumption should differ from actual odd key"
    );

    let mut initiator = HandshakeState::new_xk_initiator(kp_a, assumed_even_b);
    initiator.set_local_epoch(generate_epoch());
    let mut responder = HandshakeState::new_xk_responder(kp_b);
    responder.set_local_epoch(generate_epoch());

    let msg1 = initiator.write_xk_message_1().unwrap();
    responder.read_xk_message_1(&msg1).unwrap();
    let msg2 = responder.write_xk_message_2().unwrap();
    initiator.read_xk_message_2(&msg2).unwrap();
    let msg3 = initiator.write_xk_message_3().unwrap();
    responder.read_xk_message_3(&msg3).unwrap();
    assert!(initiator.is_complete());
    assert!(responder.is_complete());

    let mut sender = initiator.into_session().unwrap();
    let mut receiver = responder.into_session().unwrap();
    let counter = sender.current_send_counter();
    let ciphertext = sender.encrypt(b"xk parity test").unwrap();
    let plaintext = receiver
        .decrypt_with_replay_check(&ciphertext, counter)
        .unwrap();
    assert_eq!(plaintext, b"xk parity test");
}
/// The XK responder rejects a message 1 of the wrong length. Two cases are tested:
/// an IK-sized message 1, which is longer than `XK_HANDSHAKE_MSG1_SIZE`, and a
/// 10-byte buffer.
#[test]
fn test_xk_invalid_msg1_size() {
    let mut responder = HandshakeState::new_xk_responder(generate_keypair());
    responder.set_local_epoch(generate_epoch());

    let ik_sized = [0u8; HANDSHAKE_MSG1_SIZE];
    let truncated = [0u8; 10];
    assert!(responder.read_xk_message_1(&ik_sized).is_err());
    assert!(responder.read_xk_message_1(&truncated).is_err());
}
/// Once the responder has handled message 1 and written message 2, it must reject
/// a message 3 that is too short or one byte too long.
#[test]
fn test_xk_invalid_msg3_size() {
    let keypair1 = generate_keypair();
    let keypair2 = generate_keypair();

    let mut initiator = HandshakeState::new_xk_initiator(keypair1, keypair2.public_key());
    initiator.set_local_epoch(generate_epoch());
    let mut responder = HandshakeState::new_xk_responder(keypair2);
    responder.set_local_epoch(generate_epoch());

    // Advance the responder until it expects message 3. Message 2 is produced but
    // never delivered to the initiator.
    let msg1 = initiator.write_xk_message_1().unwrap();
    responder.read_xk_message_1(&msg1).unwrap();
    let _msg2 = responder.write_xk_message_2().unwrap();

    let too_short = [0u8; 10];
    let too_long = [0u8; XK_HANDSHAKE_MSG3_SIZE + 1];
    assert!(responder.read_xk_message_3(&too_short).is_err());
    assert!(responder.read_xk_message_3(&too_long).is_err());
}
/// Seals a message "off-task" with a cloned send cipher, using no AAD and an
/// explicit counter. The peer session must open it with `decrypt_with_replay_check`.
///
/// NOTE(review): despite its name, this test drives `send_cipher_clone` plus
/// `seal_in_place_append_tag` directly, not an `encrypt_with_counter` method.
/// Counter 0 is used because `sender` has not encrypted anything yet. It is not
/// reserved via `take_send_counter`, so a later send on `sender` would presumably
/// reuse nonce 0.
#[test]
fn test_encrypt_with_counter_no_aad_roundtrip() {
    let keypair1 = generate_keypair();
    let keypair2 = generate_keypair();

    let mut init = HandshakeState::new_initiator(keypair1, keypair2.public_key());
    init.set_local_epoch(generate_epoch());
    let mut resp = HandshakeState::new_responder(keypair2);
    resp.set_local_epoch(generate_epoch());

    let msg1 = init.write_message_1().unwrap();
    resp.read_message_1(&msg1).unwrap();
    let msg2 = resp.write_message_2().unwrap();
    init.read_message_2(&msg2).unwrap();

    let sender = init.into_session().unwrap();
    let mut receiver = resp.into_session().unwrap();

    // Seal on a cloned cipher, as an off-task worker would.
    let worker_cipher = sender.send_cipher_clone().unwrap();
    let counter: u64 = 0;
    let plaintext = b"off-task encrypt";
    let mut sealed = plaintext.to_vec();
    worker_cipher
        .seal_in_place_append_tag(
            CipherState::counter_to_nonce(counter),
            ring::aead::Aad::empty(),
            &mut sealed,
        )
        .expect("worker AEAD encrypt");

    let opened = receiver
        .decrypt_with_replay_check(&sealed, counter)
        .unwrap();
    assert_eq!(opened, plaintext);
}
/// `encrypt_with_counter` at counter N must produce exactly the bytes that the
/// stateful `encrypt` produces when its internal nonce is N.
///
/// The two paths must also differ in side effects: `encrypt` advances its nonce by
/// one, while `encrypt_with_counter` leaves the nonce untouched.
#[test]
fn test_encrypt_with_counter_matches_internal_counter() {
    let key = [0x42u8; 32];
    let mut a = CipherState::new(key);
    let b = CipherState::new(key);
    let plaintext = b"same key, same counter, same output";

    let counter_a = a.nonce();
    let ct_a = a.encrypt(plaintext).unwrap();
    let ct_b = b.encrypt_with_counter(plaintext, counter_a).unwrap();
    assert_eq!(
        ct_a, ct_b,
        "explicit-counter encrypt must be byte-identical"
    );

    // The stateful path consumed one nonce; the explicit-counter path consumed none.
    assert_eq!(a.nonce(), counter_a + 1);
    assert_eq!(b.nonce(), 0);
}
/// Pipelined-send flow:
///
/// 1. Reserve a send counter on the session.
/// 2. Seal with a cloned send cipher at that counter, binding an AAD.
/// 3. The peer must open the message with `decrypt_with_replay_check_and_aad`,
///    using the same AAD.
#[test]
fn test_encrypt_with_counter_and_aad_roundtrip_via_session() {
    let keypair1 = generate_keypair();
    let keypair2 = generate_keypair();

    let mut init = HandshakeState::new_initiator(keypair1, keypair2.public_key());
    init.set_local_epoch(generate_epoch());
    let mut resp = HandshakeState::new_responder(keypair2);
    resp.set_local_epoch(generate_epoch());

    let msg1 = init.write_message_1().unwrap();
    resp.read_message_1(&msg1).unwrap();
    let msg2 = resp.write_message_2().unwrap();
    init.read_message_2(&msg2).unwrap();

    let mut sender = init.into_session().unwrap();
    let mut receiver = resp.into_session().unwrap();

    let aad = b"outer header bytes";
    let plaintext = b"pipelined send";

    // Reserving the counter advances the session's send nonce immediately.
    let counter = sender.take_send_counter().unwrap();
    assert_eq!(counter, 0);
    assert_eq!(sender.send_nonce(), 1, "counter reserved → nonce advanced");

    let mut sealed = plaintext.to_vec();
    let worker_cipher = sender.send_cipher_clone().unwrap();
    worker_cipher
        .seal_in_place_append_tag(
            CipherState::counter_to_nonce(counter),
            ring::aead::Aad::from(aad),
            &mut sealed,
        )
        .unwrap();

    let opened = receiver
        .decrypt_with_replay_check_and_aad(&sealed, counter, aad)
        .unwrap();
    assert_eq!(opened, plaintext);
}
/// A cloned receive cipher opens a session-encrypted, AAD-bound message correctly.
///
/// Replay protection is driven as an explicit sequence around the off-task
/// decrypt: `check_replay`, then the decrypt, then `accept_replay`. After the
/// accept, the counter is reported as used.
#[test]
fn test_recv_cipher_clone_matches_decrypt_with_counter_and_aad() {
    let keypair1 = generate_keypair();
    let keypair2 = generate_keypair();

    let mut init = HandshakeState::new_initiator(keypair1, keypair2.public_key());
    init.set_local_epoch(generate_epoch());
    let mut resp = HandshakeState::new_responder(keypair2);
    resp.set_local_epoch(generate_epoch());

    let msg1 = init.write_message_1().unwrap();
    resp.read_message_1(&msg1).unwrap();
    let msg2 = resp.write_message_2().unwrap();
    init.read_message_2(&msg2).unwrap();

    let mut sender = init.into_session().unwrap();
    let mut receiver = resp.into_session().unwrap();

    let aad = b"AAD-bound transport header";
    let plaintext = b"off-task decrypt";
    let counter = sender.current_send_counter();
    let ciphertext = sender.encrypt_with_aad(plaintext, aad).unwrap();

    // Step 1: replay pre-check on the session; the counter is recorded only in step 3.
    assert!(receiver.check_replay(counter).is_ok());

    // Step 2: decrypt off-task with a cloned receive cipher.
    let worker_cipher = receiver.recv_cipher_clone().unwrap();
    let mut scratch = ciphertext.clone();
    let worker_plaintext = worker_cipher
        .open_in_place(
            CipherState::counter_to_nonce(counter),
            ring::aead::Aad::from(aad),
            &mut scratch,
        )
        .unwrap()
        .to_vec();
    assert_eq!(worker_plaintext, plaintext);

    // Step 3: record the counter; it must now be reported as a replay.
    receiver.accept_replay(counter);
    assert!(receiver.check_replay(counter).is_err());
}