use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use triglav::crypto::{KeyPair, NoiseSession, PublicKey, SecretKey};
use triglav::error::{CryptoError, Result};
use triglav::protocol::{Packet, PacketFlags, PacketType, HEADER_SIZE};
use triglav::types::{SequenceNumber, SessionId};
/// Builds an initiator session from an all-zero secret key.
///
/// Session construction is allowed to fail. When it succeeds, writing the
/// first handshake message must succeed too. Note that, despite its name,
/// this test does not require the zero key to be rejected.
#[test]
fn test_zero_key_rejected() {
    let all_zero_secret = SecretKey::from_bytes([0u8; 32]);
    let responder_keys = KeyPair::generate();

    match NoiseSession::new_initiator(&all_zero_secret, &responder_keys.public) {
        Ok(mut initiator) => {
            let first_message = initiator.write_handshake(&[]);
            assert!(
                first_message.is_ok(),
                "Handshake write should succeed (key validation is application responsibility)"
            );
        }
        // Refusing to build the session at all is also an accepted outcome.
        Err(_) => {}
    }
}
/// `PublicKey::from_base64` must return an error for input that is not
/// base64 and for the four-character input `"AAAA"`.
#[test]
fn test_invalid_public_key_base64() {
    // Contains characters outside the base64 alphabet.
    assert!(
        PublicKey::from_base64("not-valid-base64!!!").is_err(),
        "Should reject invalid base64"
    );

    // Far too short to be a key.
    assert!(
        PublicKey::from_base64("AAAA").is_err(),
        "Should reject wrong length key"
    );
}
/// `SecretKey::from_base64` must return an error for input that is not
/// base64 and for the four-character input `"AAAA"`.
#[test]
fn test_invalid_secret_key_base64() {
    // Contains characters outside the base64 alphabet.
    assert!(
        SecretKey::from_base64("invalid!@#$%").is_err(),
        "Should reject invalid base64"
    );

    // Far too short to be a key.
    assert!(
        SecretKey::from_base64("AAAA").is_err(),
        "Should reject wrong length key"
    );
}
/// Encoding a freshly generated key to base64 and decoding it again must
/// yield the same bytes, for both the public and the secret half.
#[test]
fn test_key_roundtrip_integrity() {
    let original = KeyPair::generate();

    // Public half.
    {
        let text = original.public.to_base64();
        let restored = PublicKey::from_base64(&text).unwrap();
        assert_eq!(original.public.as_bytes(), restored.as_bytes());
    }

    // Secret half.
    {
        let text = original.secret.to_base64();
        let restored = SecretKey::from_base64(&text).unwrap();
        assert_eq!(original.secret.as_bytes(), restored.as_bytes());
    }
}
/// Runs a handshake in which the initiator was configured with a public key
/// that does not belong to the responder.
///
/// A failure at any handshake step is accepted and ends the test. If the
/// handshake nevertheless completes, both sides must be in transport mode
/// and the responder must fail to decrypt the initiator's message.
#[test]
fn test_wrong_server_key_handshake_fails() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();
    let unrelated_keys = KeyPair::generate();

    // Deliberately point the initiator at the unrelated public key.
    let mut client =
        NoiseSession::new_initiator(&initiator_keys.secret, &unrelated_keys.public).unwrap();
    let mut server = NoiseSession::new_responder(&responder_keys.secret).unwrap();

    let first_message = client.write_handshake(&[]).unwrap();

    // Any handshake step failing is an acceptable outcome.
    if server.read_handshake(&first_message).is_err() {
        return;
    }
    let second_message = match server.write_handshake(&[]) {
        Ok(message) => message,
        Err(_) => return,
    };
    if client.read_handshake(&second_message).is_err() {
        return;
    }

    // The handshake went through anyway: the sessions must not share keys.
    assert!(client.is_transport());
    assert!(server.is_transport());

    let plaintext = b"test message";
    let sealed = client.encrypt(plaintext).unwrap();
    assert!(
        server.decrypt(&sealed).is_err(),
        "Decryption should fail with wrong keys"
    );
}
/// A responder given only the first half of the initiator's opening
/// handshake message must return an error from `read_handshake`.
#[test]
fn test_truncated_handshake_message() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();
    let mut initiator =
        NoiseSession::new_initiator(&initiator_keys.secret, &responder_keys.public).unwrap();
    let mut responder = NoiseSession::new_responder(&responder_keys.secret).unwrap();

    let opening = initiator.write_handshake(&[]).unwrap();
    let half_len = opening.len() / 2;

    assert!(
        responder.read_handshake(&opening[..half_len]).is_err(),
        "Should reject truncated handshake"
    );
}
/// A responder must return an error from `read_handshake` when the leading
/// bytes (up to ten) of the opening handshake message have been inverted.
#[test]
fn test_corrupted_handshake_message() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();
    let mut initiator =
        NoiseSession::new_initiator(&initiator_keys.secret, &responder_keys.public).unwrap();
    let mut responder = NoiseSession::new_responder(&responder_keys.secret).unwrap();

    let mut opening = initiator.write_handshake(&[]).unwrap();

    // Invert every bit of the first ten bytes, or of the whole message if
    // it is shorter than that.
    for byte in opening.iter_mut().take(10) {
        *byte ^= 0xFF;
    }

    assert!(
        responder.read_handshake(&opening).is_err(),
        "Should reject corrupted handshake"
    );
}
/// Drives a two-message handshake to completion and returns the resulting
/// `(initiator, responder)` sessions.
///
/// The initiator is built from `client_secret` and `server_public`, the
/// responder from `server_secret`. Both handshake messages carry an empty
/// payload.
///
/// # Panics
///
/// Panics if any construction or handshake step returns an error, or if
/// either session is not in transport mode after the exchange.
fn complete_handshake(
    client_secret: &SecretKey,
    server_secret: &SecretKey,
    server_public: &PublicKey,
) -> (NoiseSession, NoiseSession) {
    let mut initiator = NoiseSession::new_initiator(client_secret, server_public).unwrap();
    let mut responder = NoiseSession::new_responder(server_secret).unwrap();

    // Initiator -> responder.
    let request = initiator.write_handshake(&[]).unwrap();
    let _ = responder.read_handshake(&request).unwrap();

    // Responder -> initiator.
    let response = responder.write_handshake(&[]).unwrap();
    let _ = initiator.read_handshake(&response).unwrap();

    assert!(initiator.is_transport());
    assert!(responder.is_transport());

    (initiator, responder)
}
/// Flipping the lowest bit of a ciphertext's first byte must make
/// decryption on the peer fail.
#[test]
fn test_ciphertext_bit_flip_detected() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();
    let (mut sender, mut receiver) = complete_handshake(
        &initiator_keys.secret,
        &responder_keys.secret,
        &responder_keys.public,
    );

    let mut tampered = sender.encrypt(b"sensitive data").unwrap();
    tampered[0] ^= 0x01;

    assert!(
        receiver.decrypt(&tampered).is_err(),
        "Should detect bit flip in ciphertext"
    );
}
/// Dropping the final byte of a ciphertext must make decryption on the
/// peer fail.
#[test]
fn test_ciphertext_truncation_detected() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();
    let (mut sender, mut receiver) = complete_handshake(
        &initiator_keys.secret,
        &responder_keys.secret,
        &responder_keys.public,
    );

    let sealed = sender.encrypt(b"sensitive data").unwrap();
    let shortened_len = sealed.len() - 1;

    assert!(
        receiver.decrypt(&sealed[..shortened_len]).is_err(),
        "Should detect truncated ciphertext"
    );
}
/// Appends extra bytes to a valid ciphertext and decrypts it on the peer.
///
/// Decryption may fail. If it succeeds, the recovered plaintext must equal
/// the original message exactly, with none of the appended bytes in it.
#[test]
fn test_ciphertext_extension_detected() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();
    let (mut sender, mut receiver) = complete_handshake(
        &initiator_keys.secret,
        &responder_keys.secret,
        &responder_keys.public,
    );

    let plaintext = b"sensitive data";
    let mut extended = sender.encrypt(plaintext).unwrap();
    extended.extend_from_slice(b"extra garbage");

    if let Ok(recovered) = receiver.decrypt(&extended) {
        assert_eq!(
            recovered, plaintext,
            "Extra data should not appear in plaintext"
        );
    }
}
/// Inverting the final 16 bytes of a ciphertext must make decryption on the
/// peer fail.
#[test]
fn test_auth_tag_tampering_detected() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();
    let (mut sender, mut receiver) = complete_handshake(
        &initiator_keys.secret,
        &responder_keys.secret,
        &responder_keys.public,
    );

    let mut sealed = sender.encrypt(b"sensitive data").unwrap();

    // The test treats the trailing 16 bytes as the authentication tag.
    let tag_start = sealed.len() - 16;
    for byte in &mut sealed[tag_start..] {
        *byte ^= 0xFF;
    }

    assert!(
        receiver.decrypt(&sealed).is_err(),
        "Should detect auth tag tampering"
    );
}
/// Decrypting a zero-length input on an established session must fail.
#[test]
fn test_empty_ciphertext_rejected() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();
    let (_, mut receiver) = complete_handshake(
        &initiator_keys.secret,
        &responder_keys.secret,
        &responder_keys.public,
    );

    let no_bytes: &[u8] = &[];
    assert!(
        receiver.decrypt(no_bytes).is_err(),
        "Should reject empty ciphertext"
    );
}
/// Decrypting a 15-byte all-zero input on an established session must fail.
#[test]
fn test_short_ciphertext_rejected() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();
    let (_, mut receiver) = complete_handshake(
        &initiator_keys.secret,
        &responder_keys.secret,
        &responder_keys.public,
    );

    // One byte fewer than the 16-byte tag length assumed elsewhere in this file.
    let too_short = [0u8; 15];

    assert!(
        receiver.decrypt(&too_short).is_err(),
        "Should reject ciphertext shorter than auth tag"
    );
}
/// A ciphertext decrypts successfully once; feeding the same ciphertext to
/// the same session a second time must fail.
#[test]
fn test_replay_same_ciphertext() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();
    let (mut sender, mut receiver) = complete_handshake(
        &initiator_keys.secret,
        &responder_keys.secret,
        &responder_keys.public,
    );

    let plaintext = b"original message";
    let sealed = sender.encrypt(plaintext).unwrap();

    // First delivery is legitimate.
    let first_delivery = receiver.decrypt(&sealed).unwrap();
    assert_eq!(first_delivery, plaintext);

    // Second delivery of the very same bytes is a replay.
    assert!(
        receiver.decrypt(&sealed).is_err(),
        "Replay of same ciphertext should fail"
    );
}
/// Encrypts three messages in order, delivers the first, then skips the
/// second and delivers the third: decrypting the third must fail.
#[test]
fn test_out_of_order_decryption() {
    let client_kp = KeyPair::generate();
    let server_kp = KeyPair::generate();
    let (mut client, mut server) =
        complete_handshake(&client_kp.secret, &server_kp.secret, &server_kp.public);

    let ct1 = client.encrypt(b"message 1").unwrap();
    // The second message is still encrypted, so the sender produces three
    // ciphertexts in sequence, but it is never delivered. The underscore
    // prefix marks the binding as intentionally unused and silences the
    // `unused_variables` warning.
    let _ct2 = client.encrypt(b"message 2").unwrap();
    let ct3 = client.encrypt(b"message 3").unwrap();

    let r1 = server.decrypt(&ct1);
    assert!(r1.is_ok());

    // The receiver has not seen message 2 yet.
    let r3 = server.decrypt(&ct3);
    assert!(r3.is_err(), "Out-of-order decryption should fail");
}
/// Two different clients each complete a handshake with the same server
/// key. A ciphertext from the first client decrypts in its own session, and
/// must fail to decrypt in the server side of the second session.
#[test]
fn test_cross_session_decryption_fails() {
    let client1_kp = KeyPair::generate();
    let client2_kp = KeyPair::generate();
    let server_kp = KeyPair::generate();

    let (mut client1, mut server1) =
        complete_handshake(&client1_kp.secret, &server_kp.secret, &server_kp.public);
    // Only the server half of the second session is exercised. The client
    // half stays bound (but unused) so it lives until the end of the test,
    // without triggering `unused_variables` / `unused_mut` warnings.
    let (_client2, mut server2) =
        complete_handshake(&client2_kp.secret, &server_kp.secret, &server_kp.public);

    let plaintext = b"secret from client1";
    let ciphertext = client1.encrypt(plaintext).unwrap();

    // Sanity check: the ciphertext is valid within its own session.
    let decrypted = server1.decrypt(&ciphertext).unwrap();
    assert_eq!(decrypted, plaintext);

    let result = server2.decrypt(&ciphertext);
    assert!(result.is_err(), "Cross-session decryption should fail");
}
/// Two handshakes between the same pair of static keys must still produce
/// independent sessions: identical plaintext encrypts to different
/// ciphertexts, and neither server can decrypt the other session's message.
#[test]
fn test_different_sessions_different_keys() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();

    let (mut first_client, mut first_server) = complete_handshake(
        &initiator_keys.secret,
        &responder_keys.secret,
        &responder_keys.public,
    );
    let (mut second_client, mut second_server) = complete_handshake(
        &initiator_keys.secret,
        &responder_keys.secret,
        &responder_keys.public,
    );

    let plaintext = b"same plaintext";
    let first_sealed = first_client.encrypt(plaintext).unwrap();
    let second_sealed = second_client.encrypt(plaintext).unwrap();

    assert_ne!(
        first_sealed, second_sealed,
        "Same plaintext should produce different ciphertexts in different sessions"
    );

    // Swap the ciphertexts between the sessions; both directions must fail.
    let second_reads_first = second_server.decrypt(&first_sealed);
    assert!(
        second_reads_first.is_err(),
        "Cross-session decryption should fail"
    );
    let first_reads_second = first_server.decrypt(&second_sealed);
    assert!(
        first_reads_second.is_err(),
        "Cross-session decryption should fail"
    );
}
/// Each call to `encrypt` must advance the value reported by
/// `nonce_counter` by exactly one, checked over ten consecutive messages.
#[test]
fn test_nonce_increments() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();
    let (mut sender, _receiver) = complete_handshake(
        &initiator_keys.secret,
        &responder_keys.secret,
        &responder_keys.public,
    );

    let baseline = sender.nonce_counter().unwrap();

    for step in 1..=10u64 {
        let payload = format!("message {}", step);
        let _ = sender.encrypt(payload.as_bytes()).unwrap();

        let current = sender.nonce_counter().unwrap();
        assert_eq!(current, baseline + step, "Nonce should increment");
    }
}
/// After five messages, the sender rekeys its outgoing direction and the
/// receiver rekeys its incoming direction; a message sent afterwards must
/// still decrypt to the original plaintext.
#[test]
fn test_rekey_operation() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();
    let (mut sender, mut receiver) = complete_handshake(
        &initiator_keys.secret,
        &responder_keys.secret,
        &responder_keys.public,
    );

    // Exchange some traffic before rekeying.
    for index in 0..5 {
        let payload = format!("msg {}", index);
        let sealed = sender.encrypt(payload.as_bytes()).unwrap();
        let _opened = receiver.decrypt(&sealed).unwrap();
    }

    // Both ends rekey the same direction of the channel.
    sender.rekey_outgoing().unwrap();
    receiver.rekey_incoming().unwrap();

    let sealed = sender.encrypt(b"after rekey").unwrap();
    let opened = receiver.decrypt(&sealed).unwrap();
    assert_eq!(opened, b"after rekey");
}
/// Inverting the byte at offset 4 of an encoded data packet must make
/// `Packet::decode` return an error.
#[test]
fn test_packet_checksum_validation() {
    let payload = b"test payload".to_vec();
    let original = Packet::new(
        PacketType::Data,
        SequenceNumber(1),
        SessionId::generate(),
        1,
        payload,
    )
    .unwrap();

    let mut wire_bytes = original.encode().unwrap();
    // NOTE(review): offset 4 is presumably inside the header — confirm
    // against the packet layout.
    wire_bytes[4] ^= 0xFF;

    assert!(
        Packet::decode(&wire_bytes).is_err(),
        "Should detect corrupted header via checksum"
    );
}
/// Inverting the byte at offset 4 of an encoded data packet with a short
/// payload must make `Packet::decode` return an error.
#[test]
fn test_packet_header_corruption() {
    let payload = b"test".to_vec();
    let original = Packet::new(
        PacketType::Data,
        SequenceNumber(1),
        SessionId::generate(),
        1,
        payload,
    )
    .unwrap();

    let mut wire_bytes = original.encode().unwrap();
    // NOTE(review): offset 4 is presumably inside the header — confirm
    // against the packet layout.
    wire_bytes[4] ^= 0xFF;

    assert!(
        Packet::decode(&wire_bytes).is_err(),
        "Header corruption should be detected"
    );
}
/// Overwrites the first byte of an encoded packet with `0xFF` and decodes
/// it.
///
/// Decoding may fail. If it succeeds, the decoded header's version must
/// differ from `triglav::PROTOCOL_VERSION`.
#[test]
fn test_packet_version_mismatch() {
    let payload = b"test".to_vec();
    let original = Packet::new(
        PacketType::Data,
        SequenceNumber(1),
        SessionId::generate(),
        1,
        payload,
    )
    .unwrap();

    let mut wire_bytes = original.encode().unwrap();
    // NOTE(review): byte 0 is presumably the version field — confirm
    // against the packet layout.
    wire_bytes[0] = 0xFF;

    if let Ok(decoded) = Packet::decode(&wire_bytes) {
        assert_ne!(
            decoded.header.version,
            triglav::PROTOCOL_VERSION,
            "Should reject or flag wrong version"
        );
    }
}
/// A buffer one byte shorter than `HEADER_SIZE` must be rejected by
/// `Packet::decode`.
#[test]
fn test_undersized_packet_rejected() {
    let undersized_len = HEADER_SIZE - 1;
    let buffer = vec![0u8; undersized_len];

    assert!(
        Packet::decode(&buffer).is_err(),
        "Should reject undersized packet"
    );
}
/// Calling `encrypt` on an initiator that has not exchanged any handshake
/// message yet must return an error.
#[test]
fn test_encrypt_before_handshake_fails() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();
    let mut fresh_initiator =
        NoiseSession::new_initiator(&initiator_keys.secret, &responder_keys.public).unwrap();

    let attempt = fresh_initiator.encrypt(b"test");

    assert!(
        attempt.is_err(),
        "Should not encrypt before handshake complete"
    );
}
/// Calling `decrypt` on a responder that has not processed any handshake
/// message yet must return an error.
#[test]
fn test_decrypt_before_handshake_fails() {
    let responder_keys = KeyPair::generate();
    let mut fresh_responder = NoiseSession::new_responder(&responder_keys.secret).unwrap();

    // Arbitrary 32 zero bytes standing in for a ciphertext.
    let bogus_ciphertext = [0u8; 32];
    let attempt = fresh_responder.decrypt(&bogus_ciphertext);

    assert!(
        attempt.is_err(),
        "Should not decrypt before handshake complete"
    );
}
/// Once the handshake has completed, a further `write_handshake` call on
/// the initiator must return an error.
#[test]
fn test_handshake_after_transport_fails() {
    let initiator_keys = KeyPair::generate();
    let responder_keys = KeyPair::generate();
    let (mut established, _peer) = complete_handshake(
        &initiator_keys.secret,
        &responder_keys.secret,
        &responder_keys.public,
    );

    let late_handshake = established.write_handshake(&[]);

    assert!(
        late_handshake.is_err(),
        "Should not allow handshake after transport mode"
    );
}
/// A signature verifies against the message it was produced for and fails
/// to verify against a different message.
#[test]
fn test_signature_verification() {
    use triglav::crypto::SigningKeyPair;

    let signer = SigningKeyPair::generate();
    let signed_message = b"important data";
    let signature = signer.sign(signed_message);

    let genuine = signer.verify(signed_message, &signature);
    assert!(genuine.is_ok());

    let mismatched = signer.verify(b"different data", &signature);
    assert!(mismatched.is_err());
}
/// Inverting the first byte of a signature must make verification of the
/// original message fail.
#[test]
fn test_signature_tampering_detected() {
    use triglav::crypto::SigningKeyPair;

    let signer = SigningKeyPair::generate();
    let signed_message = b"important data";

    let mut forged = signer.sign(signed_message);
    forged[0] ^= 0xFF;

    let outcome = signer.verify(signed_message, &forged);
    assert!(outcome.is_err());
}
/// A signature made by one key pair must fail verification against the
/// public bytes of a different key pair.
#[test]
fn test_verify_with_wrong_public_key() {
    use triglav::crypto::SigningKeyPair;

    let signer = SigningKeyPair::generate();
    let stranger = SigningKeyPair::generate();

    let message = b"data";
    let signature = signer.sign(message);

    let stranger_public = stranger.public_bytes();
    let outcome = SigningKeyPair::verify_with_public(&stranger_public, message, &signature);

    assert!(outcome.is_err(), "Should fail with wrong public key");
}
/// `secure_compare` must report two byte arrays with identical contents as
/// equal.
#[test]
fn test_secure_compare_equal() {
    use triglav::crypto::secure_compare;

    let left = [1u8, 2, 3, 4, 5];
    let right = [1u8, 2, 3, 4, 5];

    let equal = secure_compare(&left, &right);
    assert!(equal, "Equal arrays should compare equal");
}
/// `secure_compare` must report two equal-length byte arrays as unequal
/// when they differ only in their last byte.
#[test]
fn test_secure_compare_unequal() {
    use triglav::crypto::secure_compare;

    let left = [1u8, 2, 3, 4, 5];
    let right = [1u8, 2, 3, 4, 6];

    let equal = secure_compare(&left, &right);
    assert!(!equal, "Different arrays should not compare equal");
}
/// `secure_compare` must report inputs of different lengths as unequal,
/// even when the shorter one is a prefix of the longer one.
#[test]
fn test_secure_compare_different_lengths() {
    use triglav::crypto::secure_compare;

    let longer = [1u8, 2, 3, 4, 5];
    let shorter = [1u8, 2, 3];

    let equal = secure_compare(&longer, &shorter);
    assert!(!equal, "Different length arrays should not compare equal");
}