use alloc::boxed::Box;
use alloc::collections::BTreeMap;
#[cfg(feature = "alloc")]
use alloc::string::ToString;
use alloc::vec::Vec;
#[cfg(feature = "std")]
use core::hash::Hasher;
/// Small multiplicative byte hasher (base 31, wrapping arithmetic).
///
/// NOTE(review): nothing in this module references it — the registry table is a
/// `BTreeMap` — which is why `dead_code` is allowed below; confirm whether it can go.
#[cfg(feature = "std")]
#[derive(Default)]
#[allow(dead_code)] // unused by `AeadRegistry`; kept pending the review note above
struct AlgorithmHasher {
    state: u64,
}

#[cfg(feature = "std")]
impl Hasher for AlgorithmHasher {
    /// Folds every byte into the running state as `state * 31 + byte`, wrapping on overflow.
    fn write(&mut self, bytes: &[u8]) {
        self.state = bytes
            .iter()
            .fold(self.state, |acc, &b| acc.wrapping_mul(31).wrapping_add(u64::from(b)));
    }

    /// Returns the accumulated state without resetting it.
    fn finish(&self) -> u64 {
        self.state
    }
}
type AlgorithmHashMap = BTreeMap<Algorithm, AeadConstructor>;
#[cfg(feature = "std")]
use std::sync::RwLock;
use lib_q_core::{
Algorithm,
AlgorithmCategory,
Error,
Result,
};
#[cfg(not(feature = "std"))]
use spin::RwLock;
use crate::AeadWithMetadata;
use crate::metadata::AeadMetadata;
use crate::plugin::AeadPlugin;
/// Factory closure stored by [`AeadRegistry::register_algorithm`].
///
/// Each call returns a newly boxed AEAD instance or an error. The `Send + Sync`
/// bounds make the boxed closure itself safe to send and share between threads.
pub type AeadConstructor = Box<dyn Fn() -> Result<Box<dyn AeadWithMetadata>> + Send + Sync>;
/// Registry that resolves an [`Algorithm`] to an AEAD implementation.
///
/// `create_aead` consults explicitly registered constructors first and falls back to
/// plugins; built-in metadata can be queried independently of registration.
pub struct AeadRegistry {
/// Explicit constructors keyed by algorithm; registering an algorithm again
/// replaces its previous constructor (`BTreeMap::insert`).
constructors: RwLock<AlgorithmHashMap>,
/// Plugins in registration order; the first one whose `algorithm()` matches is used.
plugins: RwLock<Vec<Box<dyn AeadPlugin>>>,
/// Built-in metadata, filled once in `new()`. No method modifies it afterwards,
/// and it is not wrapped in a lock.
metadata: BTreeMap<Algorithm, &'static AeadMetadata>,
}
impl AeadRegistry {
pub fn new() -> Self {
Self {
constructors: RwLock::new(AlgorithmHashMap::new()),
plugins: RwLock::new(Vec::new()),
metadata: Self::create_metadata_map(),
}
}
fn create_metadata_map() -> BTreeMap<Algorithm, &'static AeadMetadata> {
let mut metadata = BTreeMap::new();
let known_algorithms = [
Algorithm::Saturnin,
Algorithm::Shake256Aead,
Algorithm::DuplexSpongeAead,
Algorithm::TweakAead,
Algorithm::RomulusN,
Algorithm::RomulusM,
];
for algorithm in known_algorithms {
if let Some(algorithm_metadata) = crate::metadata::get_metadata(algorithm) {
metadata.insert(algorithm, algorithm_metadata);
}
}
metadata
}
pub fn register_algorithm<F>(&self, algorithm: Algorithm, constructor: F) -> Result<()>
where
F: Fn() -> Result<Box<dyn AeadWithMetadata>> + Send + Sync + 'static,
{
if algorithm.category() != AlgorithmCategory::Aead {
return Err(Error::InvalidAlgorithm {
algorithm: "Algorithm is not an AEAD algorithm",
});
}
#[cfg(feature = "std")]
{
let mut constructors = self.constructors.write().map_err(|_| Error::InvalidState {
operation: "register_algorithm".to_string(),
reason: "Failed to acquire write lock".to_string(),
})?;
constructors.insert(algorithm, Box::new(constructor));
}
#[cfg(not(feature = "std"))]
{
let mut constructors = self.constructors.write();
constructors.insert(algorithm, Box::new(constructor));
}
Ok(())
}
pub fn register_plugin(&self, plugin: Box<dyn AeadPlugin>) -> Result<()> {
#[cfg(feature = "std")]
{
let mut plugins = self.plugins.write().map_err(|_| Error::InvalidState {
operation: "register_plugin".to_string(),
reason: "Failed to acquire write lock".to_string(),
})?;
plugins.push(plugin);
}
#[cfg(not(feature = "std"))]
{
let mut plugins = self.plugins.write();
plugins.push(plugin);
}
Ok(())
}
pub fn create_aead(&self, algorithm: Algorithm) -> Result<Box<dyn AeadWithMetadata>> {
#[cfg(feature = "std")]
{
let constructors = self.constructors.read().map_err(|_| Error::InvalidState {
operation: "create_aead".to_string(),
reason: "Failed to acquire read lock".to_string(),
})?;
if let Some(constructor) = constructors.get(&algorithm) {
return constructor();
}
}
#[cfg(not(feature = "std"))]
{
let constructors = self.constructors.read();
if let Some(constructor) = constructors.get(&algorithm) {
return constructor();
}
}
#[cfg(feature = "std")]
{
let plugins = self.plugins.read().map_err(|_| Error::InvalidState {
operation: "create_aead".to_string(),
reason: "Failed to acquire read lock".to_string(),
})?;
for plugin in plugins.iter() {
if plugin.algorithm() == algorithm {
return plugin.create();
}
}
}
#[cfg(not(feature = "std"))]
{
let plugins = self.plugins.read();
for plugin in plugins.iter() {
if plugin.algorithm() == algorithm {
return plugin.create();
}
}
}
Err(Error::UnsupportedAlgorithm {
algorithm: "Algorithm not registered".to_string(),
})
}
pub fn available_algorithms(&self) -> Vec<Algorithm> {
let mut algorithms = Vec::new();
#[cfg(feature = "std")]
{
if let Ok(constructors) = self.constructors.read() {
algorithms.extend(constructors.keys().copied());
}
}
#[cfg(not(feature = "std"))]
{
let constructors = self.constructors.read();
algorithms.extend(constructors.keys().copied());
}
#[cfg(feature = "std")]
{
if let Ok(plugins) = self.plugins.read() {
for plugin in plugins.iter() {
let algorithm = plugin.algorithm();
if !algorithms.contains(&algorithm) {
algorithms.push(algorithm);
}
}
}
}
#[cfg(not(feature = "std"))]
{
let plugins = self.plugins.read();
for plugin in plugins.iter() {
let algorithm = plugin.algorithm();
if !algorithms.contains(&algorithm) {
algorithms.push(algorithm);
}
}
}
algorithms.sort();
algorithms
}
pub fn is_available(&self, algorithm: Algorithm) -> bool {
#[cfg(feature = "std")]
{
if let Ok(constructors) = self.constructors.read() &&
constructors.contains_key(&algorithm)
{
return true;
}
}
#[cfg(not(feature = "std"))]
{
let constructors = self.constructors.read();
if constructors.contains_key(&algorithm) {
return true;
}
}
#[cfg(feature = "std")]
{
if let Ok(plugins) = self.plugins.read() {
for plugin in plugins.iter() {
if plugin.algorithm() == algorithm {
return true;
}
}
}
}
#[cfg(not(feature = "std"))]
{
let plugins = self.plugins.read();
for plugin in plugins.iter() {
if plugin.algorithm() == algorithm {
return true;
}
}
}
false
}
pub fn get_metadata(&self, algorithm: Algorithm) -> Option<&'static AeadMetadata> {
self.metadata.get(&algorithm).copied()
}
pub fn get_all_metadata(&self) -> Vec<&'static AeadMetadata> {
let available = self.available_algorithms();
available
.iter()
.filter_map(|&algorithm| self.get_metadata(algorithm))
.collect()
}
}
impl Default for AeadRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use lib_q_core::{
        Aead,
        AeadKey,
        Nonce,
    };

    use super::*;

    /// Minimal AEAD stand-in: returns fixed bytes and reports the metadata of the
    /// algorithm it was created for.
    struct MockAead {
        algorithm: Algorithm,
    }

    impl Aead for MockAead {
        fn encrypt(
            &self,
            _key: &AeadKey,
            _nonce: &Nonce,
            _plaintext: &[u8],
            _associated_data: Option<&[u8]>,
        ) -> Result<Vec<u8>> {
            Ok(alloc::vec![1, 2, 3, 4])
        }

        fn decrypt(
            &self,
            _key: &AeadKey,
            _nonce: &Nonce,
            _ciphertext: &[u8],
            _associated_data: Option<&[u8]>,
        ) -> Result<Vec<u8>> {
            Ok(alloc::vec![5, 6, 7, 8])
        }
    }

    impl AeadWithMetadata for MockAead {
        fn metadata(&self) -> &'static AeadMetadata {
            crate::metadata::get_metadata(self.algorithm).expect("Metadata not found")
        }

        fn supports_semantic_decrypt(&self) -> bool {
            false
        }
    }

    /// Registers a `MockAead` constructor for `algorithm` on `registry`.
    fn register_mock(registry: &AeadRegistry, algorithm: Algorithm) -> Result<()> {
        registry.register_algorithm(algorithm, move || {
            Ok(Box::new(MockAead { algorithm }) as Box<dyn AeadWithMetadata>)
        })
    }

    #[test]
    fn test_registry_creation() {
        let registry = AeadRegistry::new();
        assert!(registry.available_algorithms().is_empty());
    }

    #[test]
    fn test_algorithm_registration() {
        let registry = AeadRegistry::new();
        let result = registry.register_algorithm(Algorithm::Saturnin, || {
            Ok(Box::new(MockAead {
                algorithm: Algorithm::Saturnin,
            }) as Box<dyn AeadWithMetadata>)
        });
        assert!(result.is_ok());
        assert!(registry.is_available(Algorithm::Saturnin));
        assert!(
            registry
                .available_algorithms()
                .contains(&Algorithm::Saturnin)
        );
    }

    #[test]
    fn test_algorithm_creation() {
        let registry = AeadRegistry::new();
        registry
            .register_algorithm(Algorithm::Saturnin, || {
                Ok(Box::new(MockAead {
                    algorithm: Algorithm::Saturnin,
                }) as Box<dyn AeadWithMetadata>)
            })
            .unwrap();
        let aead = registry.create_aead(Algorithm::Saturnin);
        assert!(aead.is_ok());
    }

    #[test]
    fn test_invalid_algorithm_registration() {
        let registry = AeadRegistry::new();
        let result = registry.register_algorithm(Algorithm::MlKem512, || {
            Ok(Box::new(MockAead {
                algorithm: Algorithm::MlKem512,
            }) as Box<dyn AeadWithMetadata>)
        });
        assert!(result.is_err());
        if let Err(Error::InvalidAlgorithm { algorithm }) = result {
            assert!(algorithm.contains("not an AEAD algorithm"));
        } else {
            panic!("Expected InvalidAlgorithm error");
        }
    }

    #[test]
    fn test_metadata_retrieval() {
        let registry = AeadRegistry::new();
        let metadata = registry.get_metadata(Algorithm::Saturnin);
        assert!(metadata.is_some());
        if let Some(meta) = metadata {
            assert_eq!(meta.algorithm, Algorithm::Saturnin);
            assert_eq!(meta.name, "Saturnin");
            assert!(meta.key_size > 0);
            assert!(meta.nonce_size > 0);
            assert!(meta.tag_size > 0);
        }
    }

    #[test]
    fn test_unsupported_algorithm() {
        let registry = AeadRegistry::new();
        let result = registry.create_aead(Algorithm::Shake256Aead);
        assert!(result.is_err());
        if let Err(Error::UnsupportedAlgorithm { algorithm }) = result {
            assert!(algorithm.contains("not registered"));
        } else {
            panic!("Expected UnsupportedAlgorithm error");
        }
    }

    /// Re-registering an algorithm replaces its constructor; it must be listed once.
    #[test]
    fn test_reregistration_is_listed_once() {
        let registry = AeadRegistry::new();
        register_mock(&registry, Algorithm::Saturnin).unwrap();
        register_mock(&registry, Algorithm::Saturnin).unwrap();
        assert_eq!(
            registry.available_algorithms(),
            alloc::vec![Algorithm::Saturnin]
        );
        assert!(registry.create_aead(Algorithm::Saturnin).is_ok());
    }

    /// `available_algorithms` is strictly increasing (sorted, no duplicates), and
    /// `is_available` reports only what was registered.
    #[test]
    fn test_available_algorithms_sorted_and_unique() {
        let registry = AeadRegistry::new();
        register_mock(&registry, Algorithm::Shake256Aead).unwrap();
        register_mock(&registry, Algorithm::Saturnin).unwrap();
        let algorithms = registry.available_algorithms();
        assert_eq!(algorithms.len(), 2);
        assert!(algorithms.windows(2).all(|pair| pair[0] < pair[1]));
        assert!(registry.is_available(Algorithm::Shake256Aead));
        assert!(!registry.is_available(Algorithm::RomulusM));
    }

    /// `get_all_metadata` only reports metadata for registered algorithms.
    #[test]
    fn test_get_all_metadata_tracks_registration() {
        let registry = AeadRegistry::new();
        assert!(registry.get_all_metadata().is_empty());
        register_mock(&registry, Algorithm::Saturnin).unwrap();
        let all = registry.get_all_metadata();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].name, "Saturnin");
    }
}