use crate::crypto::pqc::types::*;
use aws_lc_rs::hkdf;
use aws_lc_rs::hmac;
/// Hybrid combiner that concatenates the classical and post-quantum shared
/// secrets and runs the result through HKDF-SHA256 (extract, then expand) to
/// derive a 32-byte secret.
///
/// The input keying material is `classical_secret || pqc_secret` with no length
/// prefixes. The encoding is therefore only unambiguous when each input has a
/// fixed length for a given algorithm pair.
#[derive(Debug, Clone, Copy, Default)]
pub struct ConcatenationCombiner;

impl ConcatenationCombiner {
    /// Derives a 32-byte secret from `classical_secret || pqc_secret` using
    /// HKDF-SHA256 with an empty salt and the given `info`.
    ///
    /// This is exactly [`ConcatenationCombiner::combine_with_salt`] with an
    /// empty salt.
    ///
    /// # Errors
    ///
    /// Returns `PqcError::CryptoError` if HKDF expansion or output filling fails.
    pub fn combine(
        classical_secret: &[u8],
        pqc_secret: &[u8],
        info: &[u8],
    ) -> PqcResult<SharedSecret> {
        Self::combine_with_salt(classical_secret, pqc_secret, &[], info)
    }

    /// Derives a 32-byte secret from `classical_secret || pqc_secret` using
    /// HKDF-SHA256 with the caller-supplied `salt` and `info`.
    ///
    /// # Errors
    ///
    /// Returns `PqcError::CryptoError` if HKDF expansion or output filling fails.
    pub fn combine_with_salt(
        classical_secret: &[u8],
        pqc_secret: &[u8],
        salt: &[u8],
        info: &[u8],
    ) -> PqcResult<SharedSecret> {
        // HKDF-Extract takes a single IKM slice, so the two secrets are joined
        // into one buffer first.
        // NOTE(review): this buffer holds secret material and is not zeroized
        // before it is dropped.
        let mut ikm = Vec::with_capacity(classical_secret.len() + pqc_secret.len());
        ikm.extend_from_slice(classical_secret);
        ikm.extend_from_slice(pqc_secret);

        let prk = hkdf::Salt::new(hkdf::HKDF_SHA256, salt).extract(&ikm);

        // Using `HKDF_SHA256` as the output length requests 32 bytes, which
        // matches `output` exactly.
        let mut output = [0u8; 32];
        prk.expand(&[info], hkdf::HKDF_SHA256)
            .map_err(|_| PqcError::CryptoError("HKDF expand failed".to_string()))?
            .fill(&mut output)
            .map_err(|_| PqcError::CryptoError("HKDF fill failed".to_string()))?;
        Ok(SharedSecret(output))
    }
}
/// Hybrid combiner that cascades two HKDF-SHA256 instances.
///
/// The classical secret is first extracted and expanded into a 32-byte key.
/// That key is then used as the salt when extracting the post-quantum secret,
/// and the final output is expanded with the caller's `info`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TwoStepCombiner;

impl TwoStepCombiner {
    /// Derives a 32-byte secret by chaining two HKDF-SHA256 operations:
    ///
    /// 1. `k1 = HKDF(salt = empty, ikm = classical_secret, info = empty)`
    /// 2. `out = HKDF(salt = k1, ikm = pqc_secret, info = info)`
    ///
    /// # Errors
    ///
    /// Returns `PqcError::CryptoError` if either HKDF expansion or fill fails.
    pub fn combine(
        classical_secret: &[u8],
        pqc_secret: &[u8],
        info: &[u8],
    ) -> PqcResult<SharedSecret> {
        // Stage 1: derive a 32-byte key from the classical secret alone. It is
        // kept on the stack rather than in a heap allocation.
        let classical_prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &[]).extract(classical_secret);
        let classical_key = Self::expand_sha256(&classical_prk, &[])?;

        // Stage 2: the stage-1 key salts the extraction of the PQC secret.
        let combined_prk = hkdf::Salt::new(hkdf::HKDF_SHA256, &classical_key).extract(pqc_secret);
        let output = Self::expand_sha256(&combined_prk, &[info])?;
        Ok(SharedSecret(output))
    }

    /// Expands `prk` with the given info parts into exactly 32 bytes.
    fn expand_sha256(prk: &hkdf::Prk, info: &[&[u8]]) -> PqcResult<[u8; 32]> {
        let mut okm = [0u8; 32];
        prk.expand(info, hkdf::HKDF_SHA256)
            .map_err(|_| PqcError::CryptoError("HKDF expand failed".to_string()))?
            .fill(&mut okm)
            .map_err(|_| PqcError::CryptoError("HKDF fill failed".to_string()))?;
        Ok(okm)
    }
}
/// Hybrid combiner computing
/// `HMAC-SHA256(key = classical_secret, message = pqc_secret || info)`.
#[derive(Debug, Clone, Copy, Default)]
pub struct HmacCombiner;

impl HmacCombiner {
    /// Derives a 32-byte secret as HMAC-SHA256 keyed by `classical_secret` over
    /// `pqc_secret || info`.
    ///
    /// # Errors
    ///
    /// This implementation never returns an error. The `Result` keeps the
    /// signature aligned with the other combiners.
    pub fn combine(
        classical_secret: &[u8],
        pqc_secret: &[u8],
        info: &[u8],
    ) -> PqcResult<SharedSecret> {
        let key = hmac::Key::new(hmac::HMAC_SHA256, classical_secret);

        // Stream both parts into the MAC instead of copying the PQC secret into
        // a temporary heap buffer. Incremental HMAC updates authenticate the
        // concatenation of their inputs, so the tag is unchanged.
        // NOTE(review): no length prefix separates `pqc_secret` from `info`, so
        // different splits of the same bytes produce the same output.
        let mut ctx = hmac::Context::with_key(&key);
        ctx.update(pqc_secret);
        ctx.update(info);
        let tag = ctx.sign();

        // HMAC-SHA256 tags are 32 bytes, so this copy fits exactly.
        let mut output = [0u8; 32];
        output.copy_from_slice(tag.as_ref());
        Ok(SharedSecret(output))
    }
}
/// Common interface over the hybrid secret combiners in this module, usable as
/// a trait object (`Box<dyn HybridCombiner>`).
///
/// Implementors must be `Send + Sync`, so boxed combiners can be shared across
/// threads.
pub trait HybridCombiner: Send + Sync {
    /// Human-readable name identifying the combination scheme.
    fn algorithm_name(&self) -> &'static str;

    /// Combines a classical and a post-quantum shared secret, using `info` as
    /// context, and returns the derived [`SharedSecret`].
    fn combine(
        &self,
        classical_secret: &[u8],
        pqc_secret: &[u8],
        info: &[u8],
    ) -> PqcResult<SharedSecret>;
}
impl HybridCombiner for ConcatenationCombiner {
fn combine(
&self,
classical_secret: &[u8],
pqc_secret: &[u8],
info: &[u8],
) -> PqcResult<SharedSecret> {
Self::combine(classical_secret, pqc_secret, info)
}
fn algorithm_name(&self) -> &'static str {
"NIST-SP-800-56C-Option1-Concatenation"
}
}
impl HybridCombiner for TwoStepCombiner {
fn combine(
&self,
classical_secret: &[u8],
pqc_secret: &[u8],
info: &[u8],
) -> PqcResult<SharedSecret> {
Self::combine(classical_secret, pqc_secret, info)
}
fn algorithm_name(&self) -> &'static str {
"NIST-SP-800-56C-Option2-TwoStep"
}
}
impl HybridCombiner for HmacCombiner {
fn combine(
&self,
classical_secret: &[u8],
pqc_secret: &[u8],
info: &[u8],
) -> PqcResult<SharedSecret> {
Self::combine(classical_secret, pqc_secret, info)
}
fn algorithm_name(&self) -> &'static str {
"HMAC-SHA256-Combiner"
}
}
/// Returns the combiner used when no specific scheme is requested.
///
/// This is currently [`ConcatenationCombiner`], boxed as a trait object so it
/// can be stored interchangeably with other [`HybridCombiner`] implementations.
pub fn default_combiner() -> Box<dyn HybridCombiner> {
    Box::new(ConcatenationCombiner) as Box<dyn HybridCombiner>
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes an even-length hex string into bytes (test-only helper).
    fn hex(s: &str) -> Vec<u8> {
        assert_eq!(s.len() % 2, 0, "hex string must have even length");
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).expect("valid hex"))
            .collect()
    }

    #[test]
    fn test_concatenation_combiner() {
        let classical = [1u8; 32];
        let pqc = [2u8; 32];
        let info = b"test info";
        let result = ConcatenationCombiner::combine(&classical, &pqc, info);
        assert!(result.is_ok());
        let secret = result.unwrap();
        assert_eq!(secret.as_bytes().len(), 32);
        let result2 = ConcatenationCombiner::combine(&classical, &pqc, info);
        assert_eq!(secret.as_bytes(), result2.unwrap().as_bytes());
        let different_classical = [3u8; 32];
        let result3 = ConcatenationCombiner::combine(&different_classical, &pqc, info);
        assert_ne!(secret.as_bytes(), result3.unwrap().as_bytes());
    }

    #[test]
    fn test_concatenation_combiner_with_salt() {
        let classical = [1u8; 32];
        let pqc = [2u8; 32];
        let salt = b"test salt";
        let info = b"test info";
        let result = ConcatenationCombiner::combine_with_salt(&classical, &pqc, salt, info);
        assert!(result.is_ok());
        let secret = result.unwrap();
        assert_eq!(secret.as_bytes().len(), 32);
        let different_salt = b"different salt";
        let result2 =
            ConcatenationCombiner::combine_with_salt(&classical, &pqc, different_salt, info);
        assert_ne!(secret.as_bytes(), result2.unwrap().as_bytes());
    }

    #[test]
    fn test_combine_equals_combine_with_empty_salt() {
        let classical = [1u8; 32];
        let pqc = [2u8; 32];
        let info = b"test info";
        let unsalted = ConcatenationCombiner::combine(&classical, &pqc, info).unwrap();
        let empty_salted =
            ConcatenationCombiner::combine_with_salt(&classical, &pqc, &[], info).unwrap();
        assert_eq!(unsalted.as_bytes(), empty_salted.as_bytes());
    }

    #[test]
    fn test_concatenation_combiner_matches_rfc5869_case1() {
        // RFC 5869 A.1: IKM is 22 bytes of 0x0b, split across both inputs to
        // confirm the combiner hashes their concatenation. The combiner emits
        // 32 bytes, which equal the first 32 bytes of the RFC's 42-byte OKM.
        let salt: Vec<u8> = (0x00..=0x0c).collect();
        let info: Vec<u8> = (0xf0..=0xf9).collect();
        let secret =
            ConcatenationCombiner::combine_with_salt(&[0x0bu8; 10], &[0x0bu8; 12], &salt, &info)
                .unwrap();
        assert_eq!(
            secret.as_bytes().to_vec(),
            hex("3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf")
        );
    }

    #[test]
    fn test_concatenation_combiner_matches_rfc5869_case3() {
        // RFC 5869 A.3: empty salt, empty info, IKM is 22 bytes of 0x0b.
        let secret =
            ConcatenationCombiner::combine(&[0x0bu8; 11], &[0x0bu8; 11], b"").unwrap();
        assert_eq!(
            secret.as_bytes().to_vec(),
            hex("8da4e775a563c18f715f802a063c5a31b8a11f5c5ee1879ec3454e5f3c738d2d")
        );
    }

    #[test]
    fn test_two_step_combiner() {
        let classical = [1u8; 32];
        let pqc = [2u8; 32];
        let info = b"test info";
        let result = TwoStepCombiner::combine(&classical, &pqc, info);
        assert!(result.is_ok());
        let secret = result.unwrap();
        assert_eq!(secret.as_bytes().len(), 32);
        let result2 = TwoStepCombiner::combine(&classical, &pqc, info);
        assert_eq!(secret.as_bytes(), result2.unwrap().as_bytes());
    }

    #[test]
    fn test_two_step_combiner_is_cascaded_hkdf() {
        // Stage 1 is HKDF(salt = empty, ikm = classical, info = empty).
        // Stage 2 is HKDF(salt = stage-1 output, ikm = pqc, info).
        let classical = [1u8; 32];
        let pqc = [2u8; 32];
        let info = b"test info";
        let stage1 = ConcatenationCombiner::combine(&classical, &[], b"").unwrap();
        let expected =
            ConcatenationCombiner::combine_with_salt(&[], &pqc, stage1.as_bytes(), info).unwrap();
        let actual = TwoStepCombiner::combine(&classical, &pqc, info).unwrap();
        assert_eq!(actual.as_bytes(), expected.as_bytes());
    }

    #[test]
    fn test_hmac_combiner() {
        let classical = [1u8; 32];
        let pqc = [2u8; 32];
        let info = b"test info";
        let result = HmacCombiner::combine(&classical, &pqc, info);
        assert!(result.is_ok());
        let secret = result.unwrap();
        assert_eq!(secret.as_bytes().len(), 32);
        let result2 = HmacCombiner::combine(&classical, &pqc, info);
        assert_eq!(secret.as_bytes(), result2.unwrap().as_bytes());
    }

    #[test]
    fn test_hmac_combiner_matches_rfc4231_case2() {
        // RFC 4231 test case 2: HMAC-SHA-256 with key "Jefe" over
        // "what do ya want for nothing?". The data is split across
        // `pqc_secret` and `info` because the combiner MACs their concatenation.
        let secret = HmacCombiner::combine(b"Jefe", b"what do ya ", b"want for nothing?").unwrap();
        assert_eq!(
            secret.as_bytes().to_vec(),
            hex("5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843")
        );
    }

    #[test]
    fn test_different_combiners_produce_different_outputs() {
        let classical = [1u8; 32];
        let pqc = [2u8; 32];
        let info = b"test info";
        let concat_result = ConcatenationCombiner::combine(&classical, &pqc, info).unwrap();
        let twostep_result = TwoStepCombiner::combine(&classical, &pqc, info).unwrap();
        let hmac_result = HmacCombiner::combine(&classical, &pqc, info).unwrap();
        assert_ne!(concat_result.as_bytes(), twostep_result.as_bytes());
        assert_ne!(concat_result.as_bytes(), hmac_result.as_bytes());
        assert_ne!(twostep_result.as_bytes(), hmac_result.as_bytes());
    }

    #[test]
    fn test_hybrid_combiner_trait() {
        let combiner: Box<dyn HybridCombiner> = Box::new(ConcatenationCombiner);
        assert_eq!(
            combiner.algorithm_name(),
            "NIST-SP-800-56C-Option1-Concatenation"
        );
        let classical = [1u8; 32];
        let pqc = [2u8; 32];
        let info = b"test info";
        let result = combiner.combine(&classical, &pqc, info);
        assert!(result.is_ok());
    }

    #[test]
    fn test_algorithm_names() {
        let combiners: [(&dyn HybridCombiner, &str); 3] = [
            (&ConcatenationCombiner, "NIST-SP-800-56C-Option1-Concatenation"),
            (&TwoStepCombiner, "NIST-SP-800-56C-Option2-TwoStep"),
            (&HmacCombiner, "HMAC-SHA256-Combiner"),
        ];
        for (combiner, expected) in combiners {
            assert_eq!(combiner.algorithm_name(), expected);
        }
    }

    #[test]
    fn test_default_combiner() {
        let combiner = default_combiner();
        assert_eq!(
            combiner.algorithm_name(),
            "NIST-SP-800-56C-Option1-Concatenation"
        );
    }

    #[test]
    fn test_combiner_with_various_sizes() {
        let classical_p256 = [1u8; 32];
        let classical_p384 = [1u8; 48];
        let pqc = [2u8; 32];
        let info = b"test info";
        let result1 = ConcatenationCombiner::combine(&classical_p256, &pqc, info);
        assert!(result1.is_ok());
        let result2 = ConcatenationCombiner::combine(&classical_p384, &pqc, info);
        assert!(result2.is_ok());
        assert_ne!(result1.unwrap().as_bytes(), result2.unwrap().as_bytes());
    }

    #[test]
    fn test_empty_info() {
        let classical = [1u8; 32];
        let pqc = [2u8; 32];
        let empty_info = b"";
        let result = ConcatenationCombiner::combine(&classical, &pqc, empty_info);
        assert!(result.is_ok());
    }

    #[test]
    fn test_large_info() {
        let classical = [1u8; 32];
        let pqc = [2u8; 32];
        let large_info = vec![0u8; 1024];
        let result = ConcatenationCombiner::combine(&classical, &pqc, &large_info);
        assert!(result.is_ok());
    }
}