use pqcrypto_falcon::{
falconpadded1024::{self},
ffi::{
PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_BYTES,
PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES,
PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_SECRETKEYBYTES,
},
};
use pqcrypto_mlkem::{
ffi::{
PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES, PQCLEAN_MLKEM1024_CLEAN_CRYPTO_CIPHERTEXTBYTES,
PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES,
PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES,
},
mlkem1024::{self, SharedSecret},
};
use pqcrypto_traits::kem::{Ciphertext, PublicKey};
use pqcrypto_traits::sign::SignedMessage;
use rand::RngCore;
use crate::{
errors::CryptoError,
exchange::{
encryptor,
pair::{self, KEMPair, b2ss, ss2b},
},
signatures::keypair::{SignerPair, VerifierPair, ViewOperations},
};
/// Nonce-counter threshold: when the counter is at or above this value,
/// `MessageSession::increment_nonce` resets it to 0 instead of adding 1.
/// Numerically this is `u64::MAX - 1`.
const MAX_NONCE_COUNTER: u64 = 0xFFFF_FFFF_FFFF_FFFE;
/// One party's end of an encrypted, signed two-party message session.
///
/// Sessions are built with `MessageSession::new_initiator` (ML-KEM
/// encapsulation against the peer's public key) or
/// `MessageSession::new_responder` (decapsulation of the initiator's
/// ciphertext); both paths leave the same `shared_secret` on each side.
pub struct MessageSession {
    /// Secret shared with the peer; used to construct the message encryptor.
    shared_secret: SharedSecret,
    /// Current nonce, advanced before every sent or received message.
    /// Bytes 16..24 hold a little-endian `u64` counter.
    current_nonce: [u8; 24],
    /// This party's own KEM key pair.
    kem_pair: pair::KEMPair,
    /// This party's own signing pair; signs every outgoing message.
    ds_pair: SignerPair,
    /// The peer's verifier; checks the signature on every incoming message.
    target_verifier: VerifierPair,
}
impl MessageSession {
pub fn to_bytes(&self) -> Result<Vec<u8>, CryptoError> {
let mut bytes = Vec::new();
bytes.extend_from_slice(self.kem_pair.to_bytes_uniform().as_slice());
bytes.extend_from_slice(self.ds_pair.to_bytes_uniform().as_slice());
bytes.extend_from_slice(&ss2b(&self.shared_secret));
bytes.extend_from_slice(&self.target_verifier.to_bytes());
bytes.extend_from_slice(&self.current_nonce[..]);
Ok(bytes)
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, CryptoError> {
let expected_length = PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES
+ PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES
+ PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES
+ PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_SECRETKEYBYTES
+ PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES
+ PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES
+ 24;
if bytes.len() != expected_length {
return Err(CryptoError::IncongruentLength(expected_length, bytes.len()));
}
let mut idx = 0;
let kem_pair = pair::KEMPair::from_bytes_uniform(
&bytes[idx..idx
+ PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES
+ PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES],
)?;
idx += PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES
+ PQCLEAN_MLKEM1024_CLEAN_CRYPTO_SECRETKEYBYTES;
let ds_pair = SignerPair::from_bytes_uniform(
&bytes[idx..idx
+ PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES
+ PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_SECRETKEYBYTES],
)?;
idx += PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES
+ PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_SECRETKEYBYTES;
let ss_bytes = &bytes[idx..idx + PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES];
let shared_secret = b2ss(parse_ss(ss_bytes)?);
idx += PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES;
let target_verifier = VerifierPair::from_bytes(
&bytes[idx..idx + PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES],
)?;
idx += PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES;
let current_nonce = bytes[idx..idx + 24].try_into().unwrap();
idx += 24;
if idx != bytes.len() {
return Err(CryptoError::IncongruentLength(bytes.len(), idx));
}
Ok(Self {
kem_pair,
ds_pair,
shared_secret,
target_verifier,
current_nonce,
})
}
pub fn new_initiator(
my_keypair: KEMPair, my_signer: SignerPair, base_nonce: [u8; 24], target_pubkey: &[u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_PUBLICKEYBYTES], target_verifier: &[u8; PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES], ) -> Result<(Self, [u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_CIPHERTEXTBYTES]), CryptoError> {
let pubkey = mlkem1024::PublicKey::from_bytes(target_pubkey)?;
let (shared_secret, ciphertext) = my_keypair.encapsulate(&pubkey);
let target_verifier = VerifierPair::from_bytes(target_verifier)?;
Ok((
Self {
kem_pair: my_keypair,
ds_pair: my_signer,
shared_secret,
target_verifier,
current_nonce: base_nonce,
},
ct2b(&ciphertext)?,
))
}
pub fn new_responder(
my_keypair: KEMPair, my_signer: SignerPair, base_nonce: [u8; 24], ciphertext_bytes: &[u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_CIPHERTEXTBYTES], sender_verifier: &[u8; PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_PUBLICKEYBYTES], ) -> Result<Self, CryptoError> {
let ciphertext = Ciphertext::from_bytes(ciphertext_bytes)?;
let shared_secret = my_keypair.decapsulate(&ciphertext)?;
let target_verifier = VerifierPair::from_bytes(sender_verifier)?;
Ok(Self {
kem_pair: my_keypair,
ds_pair: my_signer,
shared_secret,
target_verifier,
current_nonce: base_nonce,
})
}
pub fn craft_message(&mut self, message: &[u8]) -> Result<Vec<u8>, CryptoError> {
let sig = self.ds_pair.sign(message);
self.increment_nonce();
encryptor::Encryptor::new(self.shared_secret).encrypt(&sig.as_bytes(), &self.current_nonce)
}
pub fn validate_message(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, CryptoError> {
self.increment_nonce();
let decrypted_message = encryptor::Encryptor::new(self.shared_secret)
.decrypt(ciphertext, &self.current_nonce)?;
if decrypted_message.len() < PQCLEAN_FALCONPADDED1024_CLEAN_CRYPTO_BYTES {
return Err(CryptoError::FalconSignatureTooShort(
decrypted_message.len(),
));
}
let sm = falconpadded1024::SignedMessage::from_bytes(&decrypted_message)?;
let msg = self.target_verifier.verify_message(&sm)?;
Ok(msg)
}
fn increment_nonce(&mut self) {
let mut counter = u64::from_le_bytes(self.current_nonce[16..24].try_into().unwrap());
if counter >= MAX_NONCE_COUNTER {
counter = 0;
} else {
counter += 1;
}
self.current_nonce[16..24].copy_from_slice(&counter.to_le_bytes());
}
pub fn get_counter(&self) -> u64 {
u64::from_le_bytes(self.current_nonce[16..24].try_into().unwrap())
}
}
/// Copies an ML-KEM-1024 ciphertext into a fixed-size byte array.
///
/// # Errors
///
/// Returns `CryptoError::IncongruentLength(expected, actual)` if the
/// ciphertext's byte length is not `PQCLEAN_MLKEM1024_CLEAN_CRYPTO_CIPHERTEXTBYTES`.
fn ct2b(
    ct: &mlkem1024::Ciphertext,
) -> Result<[u8; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_CIPHERTEXTBYTES], CryptoError> {
    let slice = ct.as_bytes();
    // `TryFrom<&[u8]> for [u8; N]` checks the length and copies, with no `unsafe`.
    slice.try_into().map_err(|_| {
        CryptoError::IncongruentLength(PQCLEAN_MLKEM1024_CLEAN_CRYPTO_CIPHERTEXTBYTES, slice.len())
    })
}
/// Views `slice` as a fixed-size array reference of shared-secret length,
/// without copying.
///
/// # Errors
///
/// Returns `CryptoError::IncongruentLength(expected, actual)` if
/// `slice.len()` is not `PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES`.
pub fn parse_ss<T>(slice: &[T]) -> Result<&[T; PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES], CryptoError> {
    // `TryFrom<&[T]> for &[T; N]` checks the length and reborrows, with no `unsafe`.
    slice.try_into().map_err(|_| {
        CryptoError::IncongruentLength(PQCLEAN_MLKEM1024_CLEAN_CRYPTO_BYTES, slice.len())
    })
}
/// Draws a fresh 16-byte session identifier from `rand::rng()`.
///
/// The identifier is used as the first 16 bytes of a nonce built with
/// `create_nonce`.
pub fn gen_session_id() -> [u8; 16] {
    let mut rng = rand::rng();
    let mut id = [0u8; 16];
    rng.fill_bytes(&mut id);
    id
}
/// Builds a 24-byte nonce: the 16-byte `session_id` followed by `counter`
/// encoded as a little-endian `u64`.
pub fn create_nonce(session_id: &[u8; 16], counter: u64) -> [u8; 24] {
    let counter_bytes = counter.to_le_bytes();
    let mut nonce = [0u8; 24];
    let (id_part, counter_part) = nonce.split_at_mut(16);
    id_part.copy_from_slice(session_id);
    counter_part.copy_from_slice(&counter_bytes);
    nonce
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an initiator session toward freshly generated peer keys, with
    /// the nonce counter starting at `counter`.
    fn initiator_session(counter: u64) -> MessageSession {
        let target_kem_pair = pair::KEMPair::create();
        let target_ds_pair = SignerPair::create();
        let (session, _) = MessageSession::new_initiator(
            pair::KEMPair::create(),
            SignerPair::create(),
            create_nonce(&gen_session_id(), counter),
            &target_kem_pair.to_bytes().unwrap().0,
            &target_ds_pair.to_bytes().unwrap().0,
        )
        .unwrap();
        session
    }

    /// Performs the handshake between "alice" (initiator) and "bob"
    /// (responder). Returns `(alice_session, bob_session)`, both with the
    /// nonce counter at 0.
    fn handshake() -> (MessageSession, MessageSession) {
        let alice_kem = pair::KEMPair::create();
        let alice_ds = SignerPair::create();
        let bob_kem = pair::KEMPair::create();
        let bob_ds = SignerPair::create();
        let base_nonce = create_nonce(&gen_session_id(), 0);
        let (alice_session, ciphertext) = MessageSession::new_initiator(
            alice_kem,
            alice_ds.clone(),
            base_nonce,
            &bob_kem.to_bytes().unwrap().0,
            &bob_ds.to_bytes().unwrap().0,
        )
        .unwrap();
        let bob_session = MessageSession::new_responder(
            bob_kem,
            bob_ds,
            base_nonce,
            &ciphertext,
            &alice_ds.to_bytes().unwrap().0,
        )
        .unwrap();
        (alice_session, bob_session)
    }

    #[test]
    fn test_message_session_serialization() {
        let session = initiator_session(0);
        let serialized = session.to_bytes().unwrap();
        let deserialized = MessageSession::from_bytes(&serialized).unwrap();
        assert_eq!(session.current_nonce, deserialized.current_nonce);
        // The round trip must preserve every field, not only the nonce.
        assert_eq!(
            ss2b(&session.shared_secret),
            ss2b(&deserialized.shared_secret)
        );
        assert_eq!(serialized, deserialized.to_bytes().unwrap());
    }

    #[test]
    fn test_deserialized_session_keeps_working() {
        let (mut alice_session, bob_session) = handshake();
        let mut restored_bob =
            MessageSession::from_bytes(&bob_session.to_bytes().unwrap()).unwrap();
        let message = b"Still there, Bob?";
        let encrypted_message = alice_session.craft_message(message).unwrap();
        let raw_message = restored_bob.validate_message(&encrypted_message).unwrap();
        assert_eq!(raw_message, message);
        assert_eq!(restored_bob.current_nonce, alice_session.current_nonce);
    }

    #[test]
    fn test_from_bytes_rejects_wrong_length() {
        let serialized = initiator_session(0).to_bytes().unwrap();
        assert!(MessageSession::from_bytes(&serialized[..serialized.len() - 1]).is_err());
        let mut too_long = serialized.clone();
        too_long.push(0);
        assert!(MessageSession::from_bytes(&too_long).is_err());
        assert!(MessageSession::from_bytes(&[]).is_err());
    }

    #[test]
    fn test_full_message_exchange() {
        let (mut alice_session, mut bob_session) = handshake();
        assert_eq!(
            ss2b(&alice_session.shared_secret),
            ss2b(&bob_session.shared_secret)
        );
        let message = b"Hello, Bob! This is a secret message.";
        let encrypted_message = alice_session.craft_message(message).unwrap();
        assert_eq!(
            alice_session.current_nonce[16..24],
            [1, 0, 0, 0, 0, 0, 0, 0]
        );
        assert_eq!(bob_session.current_nonce[16..24], [0, 0, 0, 0, 0, 0, 0, 0]);
        let raw_message = bob_session.validate_message(&encrypted_message).unwrap();
        assert_eq!(bob_session.current_nonce[16..24], [1, 0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(raw_message, message);
        let reply = b"Hello, Alice! I received your message safely.";
        let encrypted_reply = bob_session.craft_message(reply).unwrap();
        let raw_reply = alice_session.validate_message(&encrypted_reply).unwrap();
        assert_eq!(alice_session.current_nonce, bob_session.current_nonce);
        assert_eq!(raw_reply, reply);
    }

    #[test]
    fn test_nonce_desync() {
        let (mut alice_session, mut bob_session) = handshake();
        assert_eq!(
            ss2b(&alice_session.shared_secret),
            ss2b(&bob_session.shared_secret)
        );
        let message = b"Hello, Bob! This is a secret message.";
        let encrypted_message = alice_session.craft_message(message).unwrap();
        assert_eq!(
            alice_session.current_nonce[16..24],
            [1, 0, 0, 0, 0, 0, 0, 0]
        );
        assert_eq!(bob_session.current_nonce[16..24], [0, 0, 0, 0, 0, 0, 0, 0]);
        // Push bob one step ahead so he decrypts with the wrong nonce.
        bob_session.increment_nonce();
        assert_eq!(bob_session.current_nonce[16..24], [1, 0, 0, 0, 0, 0, 0, 0]);
        let result = bob_session.validate_message(&encrypted_message);
        assert!(result.is_err());
    }

    #[test]
    fn test_nonce_creation() {
        let session_id = gen_session_id();
        let counter = 5;
        let nonce = create_nonce(&session_id, counter);
        assert_eq!(&nonce[..16], &session_id[..]);
        assert_eq!(&nonce[16..], &counter.to_le_bytes()[..]);
    }

    #[test]
    fn test_nonce_increment() {
        let session_id = gen_session_id();
        let mut nonce = create_nonce(&session_id, 0);
        let mut counter = u64::from_le_bytes(nonce[16..24].try_into().unwrap());
        counter += 1;
        nonce[16..24].copy_from_slice(&counter.to_le_bytes());
        assert_eq!(&nonce[..16], &session_id[..]);
        assert_eq!(u64::from_le_bytes(nonce[16..24].try_into().unwrap()), 1);
    }

    #[test]
    fn test_nonce_increment_and_counter() {
        let initial_counter = 42;
        let mut session = initiator_session(initial_counter);
        assert_eq!(session.get_counter(), initial_counter);
        session.increment_nonce();
        assert_eq!(session.get_counter(), initial_counter + 1);
    }

    #[test]
    fn test_counter_wraparound() {
        let mut session = initiator_session(MAX_NONCE_COUNTER);
        assert_eq!(session.get_counter(), MAX_NONCE_COUNTER);
        session.increment_nonce();
        assert_eq!(session.get_counter(), 0);
    }

    #[test]
    fn test_shared_secret_consistency() {
        let alice_kem = pair::KEMPair::create();
        let bob_kem = pair::KEMPair::create();
        let pubkey = mlkem1024::PublicKey::from_bytes(&bob_kem.to_bytes().unwrap().0).unwrap();
        let (alice_ss, ciphertext) = alice_kem.encapsulate(&pubkey);
        let ciphertext_bytes = ct2b(&ciphertext).unwrap();
        let ciphertext_received = mlkem1024::Ciphertext::from_bytes(&ciphertext_bytes).unwrap();
        let bob_ss = bob_kem.decapsulate(&ciphertext_received).unwrap();
        let alice_ss_bytes = ss2b(&alice_ss);
        let bob_ss_bytes = ss2b(&bob_ss);
        assert_eq!(alice_ss_bytes, bob_ss_bytes);
    }
}