use crate::error::Result;
#[cfg(feature = "alloc")]
use crate::traits::*;
#[cfg(feature = "alloc")]
extern crate alloc;
#[cfg(feature = "alloc")]
use alloc::{
format,
string::String,
vec::Vec,
};
#[cfg(feature = "getrandom")]
#[allow(unused_imports)] use getrandom;
pub use lib_q_types::{
Algorithm,
AlgorithmCategory,
SecurityLevel,
};
#[cfg(any(feature = "getrandom", feature = "rand"))]
#[allow(unused_imports)]
use rand_core::Rng;
use subtle::ConstantTimeEq;
/// Key-encapsulation (KEM) operations exposed by a crypto provider.
///
/// Every method receives the [`Algorithm`] it should operate with. The trait
/// is only compiled when the `alloc` feature is enabled.
#[cfg(feature = "alloc")]
pub trait KemOperations {
/// Generates a KEM key pair for `algorithm`.
///
/// `randomness` is optional caller-supplied input. How an implementer uses
/// it is not defined in this file — presumably `None` means "use the
/// implementer's own randomness source"; TODO confirm against implementers.
fn generate_keypair(
&self,
algorithm: Algorithm,
randomness: Option<&[u8]>,
) -> Result<KemKeypair>;
/// Encapsulates against `public_key` and returns two byte vectors.
///
/// NOTE(review): the order of the returned pair is not visible here —
/// presumably `(ciphertext, shared_secret)`; confirm before relying on it.
/// `randomness` is optional caller-supplied input (semantics as for
/// `generate_keypair`, unverified).
fn encapsulate(
&self,
algorithm: Algorithm,
public_key: &KemPublicKey,
randomness: Option<&[u8]>,
) -> Result<(Vec<u8>, Vec<u8>)>;
/// Decapsulates `ciphertext` with `secret_key` and returns a byte vector
/// (presumably the shared secret — confirm against implementers).
fn decapsulate(
&self,
algorithm: Algorithm,
secret_key: &KemSecretKey,
ciphertext: &[u8],
) -> Result<Vec<u8>>;
/// Produces the public key that corresponds to `secret_key`.
fn derive_public_key(
&self,
algorithm: Algorithm,
secret_key: &KemSecretKey,
) -> Result<KemPublicKey>;
}
/// Digital-signature operations exposed by a crypto provider.
///
/// Every method receives the [`Algorithm`] it should operate with. The trait
/// is only compiled when the `alloc` feature is enabled.
#[cfg(feature = "alloc")]
pub trait SignatureOperations {
/// Generates a signature key pair for `algorithm`.
///
/// `randomness` is optional caller-supplied input; its exact use is up to
/// the implementer (presumably `None` selects the implementer's own
/// randomness source — TODO confirm).
fn generate_keypair(
&self,
algorithm: Algorithm,
randomness: Option<&[u8]>,
) -> Result<SigKeypair>;
/// Signs `message` with `secret_key` and returns the signature bytes.
///
/// `randomness` is optional caller-supplied input (semantics not defined in
/// this file; verify against implementers).
fn sign(
&self,
algorithm: Algorithm,
secret_key: &SigSecretKey,
message: &[u8],
randomness: Option<&[u8]>,
) -> Result<Vec<u8>>;
/// Checks `signature` over `message` against `public_key`.
///
/// NOTE(review): the contract has two failure channels (`Ok(false)` and
/// `Err`). Presumably `Ok(false)` means "well-formed but invalid signature"
/// and `Err` means the check could not be carried out — confirm.
fn verify(
&self,
algorithm: Algorithm,
public_key: &SigPublicKey,
message: &[u8],
signature: &[u8],
) -> Result<bool>;
}
/// Hashing operations exposed by a crypto provider.
///
/// Only compiled when the `alloc` feature is enabled.
#[cfg(feature = "alloc")]
pub trait HashOperations {
/// Hashes `data` with `algorithm` and returns the digest as owned bytes.
///
/// The output length is not fixed by this signature; it is whatever the
/// implementer produces for the given algorithm.
fn hash(&self, algorithm: Algorithm, data: &[u8]) -> Result<Vec<u8>>;
}
/// Authenticated-encryption (AEAD) operations exposed by a crypto provider.
///
/// Every method receives the [`Algorithm`] it should operate with. The trait
/// is only compiled when the `alloc` feature is enabled.
#[cfg(feature = "alloc")]
pub trait AeadOperations {
/// Encrypts `plaintext` under `key` and `nonce`, with optional
/// `associated_data`, and returns the result as owned bytes.
///
/// NOTE(review): whether the returned bytes include the authentication tag
/// is not visible here — confirm against implementers.
fn encrypt(
&self,
algorithm: Algorithm,
key: &AeadKey,
nonce: &Nonce,
plaintext: &[u8],
associated_data: Option<&[u8]>,
) -> Result<Vec<u8>>;
/// Decrypts `ciphertext` under `key` and `nonce`, with optional
/// `associated_data`, and returns the recovered bytes.
///
/// Presumably an authentication failure surfaces as `Err` — TODO confirm.
fn decrypt(
&self,
algorithm: Algorithm,
key: &AeadKey,
nonce: &Nonce,
ciphertext: &[u8],
associated_data: Option<&[u8]>,
) -> Result<Vec<u8>>;
}
/// Entry point through which a backend exposes its operation families.
///
/// Implementers must be `Send + Sync`. Each accessor hands out an optional
/// trait object for one family; presumably `None` means the provider does not
/// offer that family (confirm against the context types that consume it).
///
/// All four accessors are gated on the `alloc` feature, so without it the
/// trait has no methods at all.
pub trait CryptoProvider: Send + Sync {
/// Returns the provider's KEM operations, if any.
#[cfg(feature = "alloc")]
fn kem(&self) -> Option<&dyn KemOperations>;
/// Returns the provider's signature operations, if any.
#[cfg(feature = "alloc")]
fn signature(&self) -> Option<&dyn SignatureOperations>;
/// Returns the provider's hash operations, if any.
#[cfg(feature = "alloc")]
fn hash(&self) -> Option<&dyn HashOperations>;
/// Returns the provider's AEAD operations, if any.
#[cfg(feature = "alloc")]
fn aead(&self) -> Option<&dyn AeadOperations>;
}
#[cfg(feature = "alloc")]
pub use crate::contexts::KemContext;
/// Namespace for stateless helpers: random bytes, hex encoding/decoding and
/// byte-slice comparison.
///
/// Several helpers exist in more than one variant; exactly one variant of each
/// is compiled, selected by the `rand`, `getrandom` and `alloc` features.
pub struct Utils;

impl Utils {
    /// Checks that a requested random-byte count lies in `1..=max`.
    ///
    /// Shared by every `random_bytes` variant that validates its argument, so
    /// the lower bound and the error payload cannot drift apart between them.
    ///
    /// # Errors
    ///
    /// Returns `Error::RandomBytesLengthInvalid` carrying the bounds and the
    /// rejected `length`.
    #[cfg(any(feature = "rand", feature = "getrandom"))]
    fn check_random_length(length: usize, max: usize) -> Result<()> {
        const MIN_RANDOM_SIZE: usize = 1;
        if (MIN_RANDOM_SIZE..=max).contains(&length) {
            Ok(())
        } else {
            Err(crate::error::Error::RandomBytesLengthInvalid {
                min: MIN_RANDOM_SIZE,
                max,
                requested: length,
            })
        }
    }

    /// Returns `length` random bytes drawn from `rand::rng()`.
    ///
    /// Variant compiled when the `rand` feature is enabled.
    ///
    /// # Errors
    ///
    /// Returns `Error::RandomBytesLengthInvalid` unless `length` is in
    /// `1..=1_048_576` (1 MiB).
    #[cfg(feature = "rand")]
    pub fn random_bytes(length: usize) -> Result<Vec<u8>> {
        const MAX_RANDOM_SIZE: usize = 1024 * 1024;
        Self::check_random_length(length, MAX_RANDOM_SIZE)?;

        let mut bytes = alloc::vec![0u8; length];
        let mut rng = rand::rng();
        rng.fill_bytes(&mut bytes);
        Ok(bytes)
    }

    /// Returns `length` random bytes obtained through `getrandom::fill`.
    ///
    /// Variant compiled with `getrandom` + `alloc` when `rand` is disabled.
    ///
    /// # Errors
    ///
    /// * `Error::RandomBytesLengthInvalid` unless `length` is in
    ///   `1..=1_048_576` (1 MiB).
    /// * `Error::RandomGenerationFailed` if `getrandom::fill` reports an
    ///   error (the underlying error value is discarded).
    #[cfg(all(feature = "getrandom", not(feature = "rand")))]
    #[cfg(feature = "alloc")]
    pub fn random_bytes(length: usize) -> Result<Vec<u8>> {
        const MAX_RANDOM_SIZE: usize = 1024 * 1024;
        Self::check_random_length(length, MAX_RANDOM_SIZE)?;

        let mut bytes = alloc::vec![0u8; length];
        getrandom::fill(&mut bytes).map_err(|_| crate::error::Error::RandomGenerationFailed {
            operation: String::from("random_bytes"),
        })?;
        Ok(bytes)
    }

    /// Placeholder for `getrandom` builds without `alloc` (and without `rand`).
    ///
    /// The length is validated against `1..=1024`, but a valid length still
    /// fails: this variant never produces bytes, on any target architecture.
    /// Presumably that is because no `'static` buffer can be handed back
    /// without an allocator.
    ///
    /// # Errors
    ///
    /// * `Error::RandomBytesLengthInvalid` for a length outside `1..=1024`.
    /// * `Error::RandomGenerationFailed` for every in-range length.
    #[cfg(all(feature = "getrandom", not(feature = "rand")))]
    #[cfg(not(feature = "alloc"))]
    pub fn random_bytes(length: usize) -> Result<&'static [u8]> {
        const MAX_RANDOM_SIZE: usize = 1024;
        Self::check_random_length(length, MAX_RANDOM_SIZE)?;

        Err(crate::error::Error::RandomGenerationFailed {
            operation: "random_bytes",
        })
    }

    /// Placeholder used when neither `rand` nor `getrandom` is enabled
    /// (`alloc` build).
    ///
    /// # Errors
    ///
    /// Always returns `Error::RandomGenerationFailed`; the argument is ignored.
    #[cfg(not(any(feature = "rand", feature = "getrandom")))]
    #[cfg(feature = "alloc")]
    pub fn random_bytes(_length: usize) -> Result<Vec<u8>> {
        Err(crate::error::Error::RandomGenerationFailed {
            operation: String::from("random_bytes"),
        })
    }

    /// Placeholder used when neither `rand` nor `getrandom` is enabled
    /// (build without `alloc`).
    ///
    /// # Errors
    ///
    /// Always returns `Error::RandomGenerationFailed`; the argument is ignored.
    #[cfg(not(any(feature = "rand", feature = "getrandom")))]
    #[cfg(not(feature = "alloc"))]
    pub fn random_bytes(_length: usize) -> Result<&'static [u8]> {
        Err(crate::error::Error::RandomGenerationFailed {
            operation: "random_bytes",
        })
    }

    /// Encodes `bytes` as lowercase hexadecimal, two digits per byte.
    ///
    /// An empty slice yields an empty string.
    #[cfg(feature = "alloc")]
    pub fn bytes_to_hex(bytes: &[u8]) -> String {
        // The output length is known up front, so reserve it once.
        let mut hex = String::with_capacity(bytes.len() * 2);
        for byte in bytes {
            hex.push_str(&format!("{byte:02x}"));
        }
        hex
    }

    /// Stand-in for builds without `alloc`.
    ///
    /// NOTE: the return value is a fixed explanatory message, not hex digits;
    /// the input is ignored. Callers must not treat it as an encoding.
    #[cfg(not(feature = "alloc"))]
    pub fn bytes_to_hex(_bytes: &[u8]) -> &'static str {
        "hex conversion not available in no_std without alloc"
    }

    /// Decodes a hexadecimal string (upper- or lowercase digits) into bytes.
    ///
    /// Leading and trailing whitespace is trimmed first. Decoding works on the
    /// raw bytes of the trimmed string, so arbitrary (including non-ASCII)
    /// input is rejected with an error instead of panicking on a character
    /// boundary, and sign characters such as `+` are not accepted as digits.
    ///
    /// # Errors
    ///
    /// * `HexDecodeError::OddLength` when the trimmed input has an odd byte
    ///   length; `char_count` is that length.
    /// * `HexDecodeError::InvalidDigit` when a pair contains anything other
    ///   than `0-9`, `a-f` or `A-F`; `pair_start` is the byte offset of the
    ///   offending pair within the trimmed input.
    #[cfg(feature = "alloc")]
    pub fn hex_to_bytes(hex: &str) -> Result<Vec<u8>> {
        use crate::error::HexDecodeError;

        /// Value of one ASCII hex digit, or `None` for any other byte.
        fn nibble(digit: u8) -> Option<u8> {
            match digit {
                b'0'..=b'9' => Some(digit - b'0'),
                b'a'..=b'f' => Some(digit - b'a' + 10),
                b'A'..=b'F' => Some(digit - b'A' + 10),
                _ => None,
            }
        }

        let hex = hex.trim();
        let char_count = hex.len();
        if !char_count.is_multiple_of(2) {
            return Err(crate::error::Error::HexDecode(HexDecodeError::OddLength {
                char_count,
            }));
        }

        let mut bytes = Vec::with_capacity(char_count / 2);
        // Length is even here, so `chunks_exact(2)` covers every byte.
        for (pair_index, pair) in hex.as_bytes().chunks_exact(2).enumerate() {
            match (nibble(pair[0]), nibble(pair[1])) {
                (Some(high), Some(low)) => bytes.push((high << 4) | low),
                _ => {
                    return Err(crate::error::Error::HexDecode(
                        HexDecodeError::InvalidDigit {
                            pair_start: pair_index * 2,
                            char_count,
                        },
                    ));
                }
            }
        }
        Ok(bytes)
    }

    /// Stand-in for builds without `alloc`.
    ///
    /// # Errors
    ///
    /// Always returns `Error::MemoryAllocationFailed`; the input is ignored.
    #[cfg(not(feature = "alloc"))]
    pub fn hex_to_bytes(_hex: &str) -> Result<&'static [u8]> {
        Err(crate::error::Error::MemoryAllocationFailed {
            operation: "hex_to_bytes",
        })
    }

    /// Compares two byte slices for equality via `subtle`'s
    /// `ConstantTimeEq::ct_eq`.
    ///
    /// Slices of different length compare unequal through an ordinary early
    /// exit, so only the contents comparison goes through `ct_eq`; the lengths
    /// themselves are not hidden.
    pub fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
        a.len() == b.len() && bool::from(a.ct_eq(b))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "alloc")]
    use crate::contexts::{
        HashContext,
        SignatureContext,
    };

    /// Key generation through a KEM context must fail cleanly when no real
    /// backend is wired in, both for the default provider (`std`) and for a
    /// freshly created context (`alloc`).
    #[test]
    fn test_provider_architecture() {
        #[cfg(feature = "std")]
        {
            let mut ctx = KemContext::with_default_provider();
            let result = ctx.generate_keypair(Algorithm::MlKem512, None);
            assert!(result.is_err());
            // Either failure mode is accepted for the default provider.
            match result {
                Err(crate::error::Error::NotImplemented { feature }) => {
                    assert!(
                        feature.contains(
                            "ML-KEM implementations are provided by the main lib-q crate"
                        )
                    );
                }
                Err(crate::error::Error::ProviderNotConfigured { operation }) => {
                    assert_eq!(operation, "KEM");
                }
                _ => panic!("Expected NotImplemented or ProviderNotConfigured"),
            }
        }
        #[cfg(feature = "alloc")]
        {
            let mut ctx = KemContext::new();
            let result = ctx.generate_keypair(Algorithm::MlKem512, None);
            assert!(result.is_err());
            if let Err(crate::error::Error::ProviderNotConfigured { operation }) = result {
                assert_eq!(operation, "KEM");
            } else {
                panic!("Expected ProviderNotConfigured error, got different error type");
            }
        }
    }

    /// Pins the numeric security level reported for the ML-KEM and ML-DSA
    /// parameter sets.
    #[test]
    fn test_algorithm_security_levels() {
        assert_eq!(Algorithm::MlKem512.security_level(), 1);
        assert_eq!(Algorithm::MlKem768.security_level(), 3);
        assert_eq!(Algorithm::MlKem1024.security_level(), 4);
        assert_eq!(Algorithm::MlDsa44.security_level(), 1);
        assert_eq!(Algorithm::MlDsa65.security_level(), 3);
        assert_eq!(Algorithm::MlDsa87.security_level(), 4);
    }

    /// Pins the category reported for one algorithm of each family.
    #[test]
    fn test_algorithm_categories() {
        assert_eq!(Algorithm::MlKem512.category(), AlgorithmCategory::Kem);
        assert_eq!(Algorithm::MlDsa44.category(), AlgorithmCategory::Signature);
        assert_eq!(Algorithm::Shake256.category(), AlgorithmCategory::Hash);
    }

    /// A fresh `KemContext` reports `ProviderNotConfigured` for "KEM".
    #[test]
    #[cfg(feature = "alloc")]
    fn test_kem_context() {
        let mut ctx = KemContext::new();
        let result = ctx.generate_keypair(Algorithm::MlKem512, None);
        assert!(result.is_err());
        if let Err(crate::error::Error::ProviderNotConfigured { operation }) = result {
            assert_eq!(operation, "KEM");
        } else {
            panic!("Expected ProviderNotConfigured error");
        }
    }

    /// A fresh `SignatureContext` reports `ProviderNotConfigured` for
    /// "signature".
    #[test]
    #[cfg(feature = "alloc")]
    fn test_signature_context() {
        let mut ctx = SignatureContext::new();
        let result = ctx.generate_keypair(Algorithm::MlDsa65, None);
        assert!(result.is_err());
        if let Err(crate::error::Error::ProviderNotConfigured { operation }) = result {
            assert_eq!(operation, "signature");
        } else {
            panic!("Expected ProviderNotConfigured error");
        }
    }

    /// A fresh `HashContext` reports `ProviderNotConfigured` for "hash".
    #[test]
    #[cfg(feature = "alloc")]
    fn test_hash_context() {
        let mut ctx = HashContext::new();
        let result = ctx.hash(Algorithm::Shake256, b"test");
        assert!(result.is_err());
        if let Err(crate::error::Error::ProviderNotConfigured { operation }) = result {
            assert_eq!(operation, "hash");
        } else {
            panic!("Expected ProviderNotConfigured error");
        }
    }

    /// Smoke test for `random_bytes` and the hex round trip.
    #[test]
    fn test_utils() {
        // Only feature sets whose `random_bytes` can succeed may unwrap it:
        // `getrandom` without `alloc` (and without `rand`) always returns `Err`.
        #[cfg(any(feature = "rand", all(feature = "getrandom", feature = "alloc")))]
        {
            let bytes = Utils::random_bytes(32).unwrap();
            assert_eq!(bytes.len(), 32);
        }
        #[cfg(feature = "alloc")]
        {
            let hex = Utils::bytes_to_hex(&[0x01, 0x23, 0x45, 0x67]);
            assert_eq!(hex, "01234567");
            let decoded = Utils::hex_to_bytes(&hex).unwrap();
            assert_eq!(decoded, alloc::vec![0x01, 0x23, 0x45, 0x67]);
        }
    }

    /// Two successive 32-byte draws differ and are not all zero. Builds whose
    /// `random_bytes` reports `RandomGenerationFailed` are tolerated, so this
    /// test compiles and passes under every feature combination.
    #[test]
    fn test_random_bytes_generation() {
        match Utils::random_bytes(32) {
            Ok(bytes1) => {
                let bytes2 = Utils::random_bytes(32).expect("Should generate random bytes");
                assert_eq!(bytes1.len(), 32);
                assert_eq!(bytes2.len(), 32);
                assert_ne!(
                    bytes1, bytes2,
                    "Random bytes should be different on subsequent calls"
                );
                let all_zero1 = bytes1.iter().all(|&b| b == 0);
                let all_zero2 = bytes2.iter().all(|&b| b == 0);
                assert!(!all_zero1, "Random bytes should not all be zero");
                assert!(!all_zero2, "Random bytes should not all be zero");
            }
            Err(crate::error::Error::RandomGenerationFailed { .. }) => {
                // No working randomness source in this build; nothing to check.
            }
            Err(e) => {
                panic!("Unexpected error: {:?}", e);
            }
        }
    }

    /// Equal slices match; differing contents or lengths do not.
    #[test]
    fn test_constant_time_compare() {
        assert!(Utils::constant_time_compare(b"hello", b"hello"));
        assert!(!Utils::constant_time_compare(b"hello", b"world"));
        assert!(!Utils::constant_time_compare(b"hello", b"hell"));
    }

    /// Byte-frequency sanity check over 1000 draws of 32 bytes.
    ///
    /// Computes a chi-square statistic against the uniform distribution
    /// (255 degrees of freedom) and converts it to an approximate z-score with
    /// the Wilson–Hilferty cube-root transform; the test fails above z = 5.
    #[cfg(any(feature = "rand", all(feature = "getrandom", feature = "alloc")))]
    #[test]
    fn test_random_bytes_entropy_quality() {
        const NUM_SAMPLES: usize = 1000;
        const BYTE_LENGTH: usize = 32;
        let mut byte_counts = [0u32; 256];
        let mut total_bytes = 0u32;
        for _ in 0..NUM_SAMPLES {
            let bytes = Utils::random_bytes(BYTE_LENGTH).expect("Should generate random bytes");
            for &byte in &bytes {
                byte_counts[byte as usize] += 1;
                total_bytes += 1;
            }
        }
        let zero_count = byte_counts.iter().filter(|&&count| count == 0).count();
        assert!(
            zero_count < 50,
            "Too many byte values are missing from random generation"
        );
        let expected_per_byte = total_bytes as f64 / 256.0;
        let chi_sq: f64 = byte_counts
            .iter()
            .map(|&count| {
                let d = count as f64 - expected_per_byte;
                d * d / expected_per_byte
            })
            .sum();
        const NU: f64 = 255.0;
        let z =
            ((chi_sq / NU).powf(1.0 / 3.0) - (1.0 - 2.0 / (9.0 * NU))) / (2.0 / (9.0 * NU)).sqrt();
        assert!(
            z <= 5.0,
            "Random bytes show poor entropy distribution (chi-square z = {})",
            z
        );
    }

    /// Concatenates 10000 draws of 16 bytes and rejects any run of more than
    /// four identical consecutive bytes.
    #[cfg(any(feature = "rand", all(feature = "getrandom", feature = "alloc")))]
    #[test]
    fn test_random_bytes_uniformity() {
        const NUM_SAMPLES: usize = 10000;
        const BYTE_LENGTH: usize = 16;
        let mut all_bytes = alloc::vec![0u8; NUM_SAMPLES * BYTE_LENGTH];
        let mut offset = 0;
        for _ in 0..NUM_SAMPLES {
            let bytes = Utils::random_bytes(BYTE_LENGTH).expect("Should generate random bytes");
            all_bytes[offset..offset + BYTE_LENGTH].copy_from_slice(&bytes);
            offset += BYTE_LENGTH;
        }
        let mut max_run_length = 0;
        let mut current_run_length = 1;
        for i in 1..all_bytes.len() {
            if all_bytes[i] == all_bytes[i - 1] {
                current_run_length += 1;
                max_run_length = max_run_length.max(current_run_length);
            } else {
                current_run_length = 1;
            }
        }
        assert!(
            max_run_length <= 4,
            "Random bytes show suspicious patterns (run length: {})",
            max_run_length
        );
    }

    /// Length bounds of the allocating `random_bytes` variants: 0 and
    /// 1 MiB + 1 are rejected, 1 MiB and a range of common sizes are accepted.
    #[cfg(any(feature = "rand", all(feature = "getrandom", feature = "alloc")))]
    #[test]
    fn test_random_bytes_size_limits() {
        const MAX_SIZE: usize = 1024 * 1024;
        assert_eq!(
            Utils::random_bytes(0),
            Err(crate::error::Error::RandomBytesLengthInvalid {
                min: 1,
                max: MAX_SIZE,
                requested: 0,
            }),
            "zero length"
        );
        assert!(
            Utils::random_bytes(MAX_SIZE).is_ok(),
            "Should accept maximum size"
        );
        assert_eq!(
            Utils::random_bytes(MAX_SIZE + 1),
            Err(crate::error::Error::RandomBytesLengthInvalid {
                min: 1,
                max: MAX_SIZE,
                requested: MAX_SIZE + 1,
            }),
            "oversized request"
        );
        for size in [1, 16, 32, 64, 128, 256, 512, 1024] {
            let bytes = Utils::random_bytes(size).expect("Should generate random bytes");
            assert_eq!(bytes.len(), size, "Should generate exactly {} bytes", size);
        }
    }

    /// `hex_to_bytes` reports odd-length input and the offset of the first
    /// pair containing a non-hex digit.
    #[test]
    #[cfg(feature = "alloc")]
    fn test_hex_to_bytes_decode_errors() {
        use crate::error::{
            Error,
            HexDecodeError,
        };
        assert_eq!(
            Utils::hex_to_bytes("123").unwrap_err(),
            Error::HexDecode(HexDecodeError::OddLength { char_count: 3 })
        );
        assert_eq!(
            Utils::hex_to_bytes("12g3").unwrap_err(),
            Error::HexDecode(HexDecodeError::InvalidDigit {
                pair_start: 2,
                char_count: 4,
            })
        );
    }
}