#[cfg(feature = "alloc")]
use alloc::string::ToString;
use super::{
EntropyValidator,
SecurityConstants,
TimingValidator,
};
use crate::api::{
Algorithm,
AlgorithmCategory,
};
use crate::error::Result;
/// Bundles the crate's input/key validation checks behind a single value.
///
/// Holds a `TimingValidator` (used for slice comparison), an
/// `EntropyValidator` (used when validating key material) and the
/// `SecurityConstants` that supply size limits and expected lengths.
/// Only available with the `alloc` feature.
#[cfg(feature = "alloc")]
#[derive(Clone)]
pub struct SecurityValidator {
    /// Delegate for `SecurityValidator::constant_time_compare`.
    timing_validator: TimingValidator,
    /// Runs the entropy check applied to key material and randomness.
    entropy_validator: EntropyValidator,
    /// Size limits and expected key/ciphertext/signature/nonce lengths
    /// consulted by the validation methods; mutable via
    /// `security_constants_mut`.
    constants: SecurityConstants,
}
#[cfg(feature = "alloc")]
impl SecurityValidator {
    /// Creates a validator with fresh timing/entropy validators and the
    /// default `SecurityConstants`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `TimingValidator::new` or
    /// `EntropyValidator::new`.
    pub fn new() -> Result<Self> {
        Ok(Self {
            timing_validator: TimingValidator::new()?,
            entropy_validator: EntropyValidator::new()?,
            constants: SecurityConstants::new(),
        })
    }

    /// Checks that `algorithm` supports `expected_category`.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidAlgorithm` when `algorithm` does not support
    /// the requested category.
    pub fn validate_algorithm_category(
        &self,
        algorithm: Algorithm,
        expected_category: AlgorithmCategory,
    ) -> Result<()> {
        if !algorithm.supports_category(expected_category) {
            return Err(crate::error::Error::InvalidAlgorithm {
                algorithm: "Algorithm category mismatch",
            });
        }
        Ok(())
    }

    /// Checks that `key_data` has exactly the length the constants table
    /// expects for `algorithm` (secret-key size when `is_secret` is `true`,
    /// public-key size otherwise).
    ///
    /// # Errors
    ///
    /// - Any error from `SecurityConstants::get_expected_key_size`
    ///   (e.g. an algorithm without a known key size).
    /// - `Error::InvalidKeySize` when the length does not match.
    pub fn validate_key_size(
        &self,
        algorithm: Algorithm,
        key_data: &[u8],
        is_secret: bool,
    ) -> Result<()> {
        let expected = self.constants.get_expected_key_size(algorithm, is_secret)?;
        if key_data.len() != expected {
            return Err(crate::error::Error::InvalidKeySize {
                expected,
                actual: key_data.len(),
            });
        }
        Ok(())
    }

    /// Rejects structurally trivial key bytes: empty, all `0x00`, or all
    /// `0xFF`. Checks run in that order, so an empty slice is reported as a
    /// size error rather than as "all zeros".
    fn ensure_non_trivial_key_bytes(key_data: &[u8]) -> Result<()> {
        if key_data.is_empty() {
            return Err(crate::error::Error::InvalidKeySize {
                expected: 1,
                actual: 0,
            });
        }
        if key_data.iter().all(|&b| b == 0) {
            return Err(crate::error::Error::InvalidKey {
                key_type: "key".to_string(),
                reason: "Key material cannot be all zeros".to_string(),
            });
        }
        if key_data.iter().all(|&b| b == 0xFF) {
            return Err(crate::error::Error::InvalidKey {
                key_type: "key".to_string(),
                reason: "Key material cannot be all ones".to_string(),
            });
        }
        Ok(())
    }

    /// Validates raw key material independent of any algorithm: first the
    /// trivial-pattern checks (empty / all zeros / all ones), then the
    /// entropy validator's `validate_key_entropy`.
    ///
    /// # Errors
    ///
    /// - `Error::InvalidKeySize` for an empty slice.
    /// - `Error::InvalidKey` for all-zero or all-`0xFF` bytes.
    /// - Any error returned by `EntropyValidator::validate_key_entropy`.
    pub fn validate_key_material(&self, key_data: &[u8]) -> Result<()> {
        Self::ensure_non_trivial_key_bytes(key_data)?;
        self.entropy_validator.validate_key_entropy(key_data)?;
        Ok(())
    }

    /// Validates a public key: exact public-key length for `algorithm`,
    /// then the checks of `validate_key_material`.
    ///
    /// # Errors
    ///
    /// Any error from `validate_key_size` or `validate_key_material`.
    pub fn validate_public_key(&self, algorithm: Algorithm, key_data: &[u8]) -> Result<()> {
        self.validate_key_size(algorithm, key_data, false)?;
        self.validate_key_material(key_data)?;
        Ok(())
    }

    /// Validates a secret key: exact secret-key length for `algorithm`,
    /// then the checks of `validate_key_material`.
    ///
    /// # Errors
    ///
    /// Any error from `validate_key_size` or `validate_key_material`.
    pub fn validate_secret_key(&self, algorithm: Algorithm, key_data: &[u8]) -> Result<()> {
        self.validate_key_size(algorithm, key_data, true)?;
        // Share the single key-material policy with `validate_public_key`
        // instead of re-listing its steps here.
        self.validate_key_material(key_data)?;
        Ok(())
    }

    /// Checks that an AEAD message does not exceed
    /// `SecurityConstants::max_aead_message_size` (the limit is inclusive).
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidMessageSize` when the message is too long.
    pub fn validate_aead_message(&self, message: &[u8]) -> Result<()> {
        let max = self.constants.max_aead_message_size();
        if message.len() > max {
            return Err(crate::error::Error::InvalidMessageSize {
                max,
                actual: message.len(),
            });
        }
        Ok(())
    }

    /// Checks that hash input does not exceed
    /// `SecurityConstants::max_hash_message_size` (the limit is inclusive).
    /// The cap can be changed through `security_constants_mut`.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidMessageSize` when the input is too long.
    pub fn validate_hash_input(&self, data: &[u8]) -> Result<()> {
        let max = self.constants.max_hash_message_size();
        if data.len() > max {
            return Err(crate::error::Error::InvalidMessageSize {
                max,
                actual: data.len(),
            });
        }
        Ok(())
    }

    /// Validates a message to be signed; applies the same size cap as
    /// `validate_hash_input`.
    ///
    /// # Errors
    ///
    /// Same as `validate_hash_input`.
    pub fn validate_signature_message(&self, message: &[u8]) -> Result<()> {
        self.validate_hash_input(message)
    }

    /// Returns the constants used by this validator.
    pub fn security_constants(&self) -> &SecurityConstants {
        &self.constants
    }

    /// Returns the constants mutably, e.g. to adjust size caps.
    pub fn security_constants_mut(&mut self) -> &mut SecurityConstants {
        &mut self.constants
    }

    /// Checks that `nonce` has exactly `SecurityConstants::standard_nonce_size`
    /// bytes and is not all zeros.
    ///
    /// # Errors
    ///
    /// - `Error::InvalidNonceSize` on a length mismatch.
    /// - `Error::InvalidKey` (with `key_type` `"nonce"`) for an all-zero nonce.
    pub fn validate_nonce(&self, nonce: &[u8]) -> Result<()> {
        let expected = self.constants.standard_nonce_size();
        if nonce.len() != expected {
            return Err(crate::error::Error::InvalidNonceSize {
                expected,
                actual: nonce.len(),
            });
        }
        if nonce.iter().all(|&b| b == 0) {
            return Err(crate::error::Error::InvalidKey {
                key_type: "nonce".to_string(),
                reason: "Nonce cannot be all zeros".to_string(),
            });
        }
        Ok(())
    }

    /// Checks that `ciphertext` is non-empty and has exactly the length the
    /// constants table expects for `algorithm`.
    ///
    /// # Errors
    ///
    /// - `Error::InvalidCiphertextSize` for empty input or a length mismatch.
    /// - Any error from `SecurityConstants::get_expected_ciphertext_size`.
    pub fn validate_ciphertext(&self, algorithm: Algorithm, ciphertext: &[u8]) -> Result<()> {
        // Empty input is rejected before the algorithm lookup runs.
        // NOTE(review): `expected: 1` is a placeholder, not the algorithm's
        // real ciphertext size — kept to preserve the existing error payload.
        if ciphertext.is_empty() {
            return Err(crate::error::Error::InvalidCiphertextSize {
                expected: 1,
                actual: 0,
            });
        }
        let expected = self.constants.get_expected_ciphertext_size(algorithm)?;
        if ciphertext.len() != expected {
            return Err(crate::error::Error::InvalidCiphertextSize {
                expected,
                actual: ciphertext.len(),
            });
        }
        Ok(())
    }

    /// Checks that `signature` is non-empty and has exactly the length the
    /// constants table expects for `algorithm`.
    ///
    /// # Errors
    ///
    /// - `Error::InvalidSignatureSize` for empty input or a length mismatch.
    /// - Any error from `SecurityConstants::get_expected_signature_size`.
    pub fn validate_signature(&self, algorithm: Algorithm, signature: &[u8]) -> Result<()> {
        // Same placeholder `expected: 1` convention as `validate_ciphertext`.
        if signature.is_empty() {
            return Err(crate::error::Error::InvalidSignatureSize {
                expected: 1,
                actual: 0,
            });
        }
        let expected = self.constants.get_expected_signature_size(algorithm)?;
        if signature.len() != expected {
            return Err(crate::error::Error::InvalidSignatureSize {
                expected,
                actual: signature.len(),
            });
        }
        Ok(())
    }

    /// Checks that caller-supplied randomness is at least
    /// `SecurityConstants::min_randomness_size` bytes long and passes
    /// `validate_key_material`.
    ///
    /// # Errors
    ///
    /// - `Error::InvalidKeySize` when too short (`expected` carries the
    ///   minimum, not an exact size).
    /// - Any error from `validate_key_material`.
    pub fn validate_randomness(&self, randomness: &[u8]) -> Result<()> {
        let min = self.constants.min_randomness_size();
        if randomness.len() < min {
            return Err(crate::error::Error::InvalidKeySize {
                expected: min,
                actual: randomness.len(),
            });
        }
        self.validate_key_material(randomness)?;
        Ok(())
    }

    /// Compares two byte slices for equality via the timing validator.
    /// Slices of different lengths compare unequal. Any constant-time
    /// guarantee is provided by `TimingValidator::constant_time_compare`.
    pub fn constant_time_compare(&self, a: &[u8], b: &[u8]) -> bool {
        self.timing_validator.constant_time_compare(a, b)
    }

    /// Returns the entropy validator used for key-material checks.
    pub fn entropy_validator(&self) -> &EntropyValidator {
        &self.entropy_validator
    }

    /// Returns the entropy validator mutably.
    pub fn entropy_validator_mut(&mut self) -> &mut EntropyValidator {
        &mut self.entropy_validator
    }
}
#[cfg(test)]
#[cfg(feature = "alloc")]
mod tests {
    use super::*;

    #[test]
    fn test_security_validator_creation() {
        let validator = SecurityValidator::new();
        assert!(
            validator.is_ok(),
            "SecurityValidator should be created successfully"
        );
    }

    #[test]
    fn test_validate_algorithm_category() {
        let validator = SecurityValidator::new().unwrap();
        let result =
            validator.validate_algorithm_category(Algorithm::MlKem512, AlgorithmCategory::Kem);
        assert!(result.is_ok(), "Should accept correct algorithm category");
        let result = validator
            .validate_algorithm_category(Algorithm::MlKem512, AlgorithmCategory::Signature);
        assert!(
            result.is_err(),
            "Should reject incorrect algorithm category"
        );
    }

    #[test]
    fn test_validate_key_material() {
        let validator = SecurityValidator::new().unwrap();
        let valid_key = vec![
            0x1A, 0x2B, 0x3C, 0x4D, 0x5E, 0x6F, 0x70, 0x81, 0x92, 0xA3, 0xB4, 0xC5, 0xD6, 0xE7,
            0xF8, 0x09,
        ];
        let result = validator.validate_key_material(&valid_key);
        assert!(result.is_ok(), "Should accept valid key material");
        let zero_key = vec![0u8; 8];
        let result = validator.validate_key_material(&zero_key);
        assert!(result.is_err(), "Should reject zero key");
        let ones_key = vec![0xFFu8; 8];
        let result = validator.validate_key_material(&ones_key);
        assert!(result.is_err(), "Should reject all-ones key");
        let empty_key: Vec<u8> = Vec::new();
        let result = validator.validate_key_material(&empty_key);
        assert!(result.is_err(), "Should reject empty key");
    }

    #[test]
    fn test_validate_secret_key_rejects_wrong_size() {
        let validator = SecurityValidator::new().unwrap();
        // Three bytes is far below any ML-KEM-512 secret-key size, so the
        // size guard (or the size lookup itself) must reject it.
        let short_key = [0x1Au8, 0x2B, 0x3C];
        assert!(
            validator
                .validate_secret_key(Algorithm::MlKem512, &short_key)
                .is_err(),
            "Should reject a secret key of the wrong length"
        );
    }

    #[test]
    fn test_validate_aead_message() {
        let validator = SecurityValidator::new().unwrap();
        let valid_message = vec![1u8; 1000];
        assert!(validator.validate_aead_message(&valid_message).is_ok());
        let oversized_message = vec![1u8; 2 * 1024 * 1024];
        assert!(validator.validate_aead_message(&oversized_message).is_err());
    }

    #[test]
    fn test_validate_hash_input_default_unbounded() {
        let validator = SecurityValidator::new().unwrap();
        let large = vec![1u8; 2 * 1024 * 1024];
        assert!(validator.validate_hash_input(&large).is_ok());
    }

    #[test]
    fn test_validate_hash_input_when_capped() {
        let mut validator = SecurityValidator::new().unwrap();
        validator
            .security_constants_mut()
            .set_max_hash_message_size(1024);
        let oversized = vec![1u8; 2048];
        assert!(validator.validate_hash_input(&oversized).is_err());
    }

    #[test]
    fn test_validate_signature_message_respects_hash_cap() {
        let mut validator = SecurityValidator::new().unwrap();
        validator
            .security_constants_mut()
            .set_max_hash_message_size(1024);
        // The cap is inclusive: exactly `max` bytes is accepted.
        assert!(validator.validate_signature_message(&[1u8; 1024]).is_ok());
        assert!(validator.validate_signature_message(&[1u8; 1025]).is_err());
    }

    #[test]
    fn test_validate_nonce() {
        let validator = SecurityValidator::new().unwrap();
        let nonce_len = validator.security_constants().standard_nonce_size();
        assert!(nonce_len > 0, "Standard nonce size should be non-zero");

        let valid_nonce = vec![0xA5u8; nonce_len];
        assert!(
            validator.validate_nonce(&valid_nonce).is_ok(),
            "Should accept a non-zero nonce of the standard size"
        );

        let zero_nonce = vec![0u8; nonce_len];
        assert!(
            validator.validate_nonce(&zero_nonce).is_err(),
            "Should reject an all-zero nonce"
        );

        let long_nonce = vec![0xA5u8; nonce_len + 1];
        assert!(
            validator.validate_nonce(&long_nonce).is_err(),
            "Should reject an over-long nonce"
        );

        let short_nonce = vec![0xA5u8; nonce_len - 1];
        assert!(
            validator.validate_nonce(&short_nonce).is_err(),
            "Should reject a short nonce"
        );
    }

    #[test]
    fn test_validate_ciphertext_rejects_empty() {
        let validator = SecurityValidator::new().unwrap();
        assert!(
            validator
                .validate_ciphertext(Algorithm::MlKem512, &[])
                .is_err(),
            "Should reject an empty ciphertext"
        );
    }

    #[test]
    fn test_constant_time_compare() {
        let validator = SecurityValidator::new().unwrap();
        let a = vec![1, 2, 3, 4];
        let b = vec![1, 2, 3, 4];
        assert!(
            validator.constant_time_compare(&a, &b),
            "Should return true for equal slices"
        );
        let c = vec![1, 2, 3, 5];
        assert!(
            !validator.constant_time_compare(&a, &c),
            "Should return false for different slices"
        );
        let d = vec![1, 2, 3];
        assert!(
            !validator.constant_time_compare(&a, &d),
            "Should return false for different length slices"
        );
    }
}