Skip to main content

lib_q_aead/
registry.rs

1//! AEAD Algorithm Registry
2//!
3//! This module provides a flexible registry system for AEAD algorithms that allows
4//! dynamic registration and creation of algorithm instances.
5
6use alloc::boxed::Box;
7use alloc::collections::BTreeMap;
8#[cfg(feature = "alloc")]
9use alloc::string::ToString;
10use alloc::vec::Vec;
11#[cfg(feature = "std")]
12use core::hash::Hasher;
13
/// Deterministic, non-cryptographic hasher intended for `Algorithm` keys.
///
/// Each input byte is folded into a `u64` as `state * 31 + byte` using
/// wrapping arithmetic. There is no seed, so colliding inputs are easy to
/// construct; do not use it for attacker-controlled keys. Nothing uses it
/// today (storage is a `BTreeMap`). It is kept only for a possible future
/// `HashMap`-based store.
#[cfg(feature = "std")]
#[derive(Default)]
#[allow(dead_code)] // Reserved for future HashMap implementation
struct AlgorithmHasher {
    /// Running hash value; `Default` starts it at 0.
    state: u64,
}

#[cfg(feature = "std")]
impl Hasher for AlgorithmHasher {
    /// Fold `bytes` into the running state, in order.
    fn write(&mut self, bytes: &[u8]) {
        self.state = bytes.iter().fold(self.state, |acc, &byte| {
            acc.wrapping_mul(31).wrapping_add(u64::from(byte))
        });
    }

    /// Return the current state without modifying it.
    fn finish(&self) -> u64 {
        self.state
    }
}
36
/// Storage type for directly registered algorithm constructors.
///
/// Despite its name, this is a `BTreeMap`, not a `HashMap`. An ordered map
/// iterates in key order on every platform and needs no hash function.
/// `AlgorithmHasher` (std builds only) is kept in case storage ever moves to
/// a `HashMap`.
type AlgorithmHashMap = BTreeMap<Algorithm, AeadConstructor>;
41#[cfg(feature = "std")]
42use std::sync::RwLock;
43
44use lib_q_core::{
45    Algorithm,
46    AlgorithmCategory,
47    Error,
48    Result,
49};
50#[cfg(not(feature = "std"))]
51use spin::RwLock;
52
53use crate::AeadWithMetadata;
54use crate::metadata::AeadMetadata;
55use crate::plugin::AeadPlugin;
56
/// Constructor function type for creating AEAD instances.
///
/// A boxed closure that, on each call, returns a newly boxed
/// [`AeadWithMetadata`] instance or an error. The closure itself must be
/// `Send + Sync`.
pub type AeadConstructor = Box<dyn Fn() -> Result<Box<dyn AeadWithMetadata>> + Send + Sync>;
59
/// Registry for AEAD algorithms.
///
/// Instances come either from directly registered constructors or from
/// registered plugins; constructors are consulted first. Both tables sit
/// behind an `RwLock` (`std::sync::RwLock` with the `std` feature,
/// `spin::RwLock` otherwise), so registration and lookup both go through
/// `&self`.
pub struct AeadRegistry {
    /// Directly registered constructors, keyed by algorithm.
    constructors: RwLock<AlgorithmHashMap>,
    /// Registered plugins, kept in registration order.
    plugins: RwLock<Vec<Box<dyn AeadPlugin>>>,
    /// Metadata for the built-in AEAD algorithms. Filled once at construction
    /// and never modified by registration.
    metadata: BTreeMap<Algorithm, &'static AeadMetadata>,
}
66
67impl AeadRegistry {
68    /// Create a new AEAD registry
69    pub fn new() -> Self {
70        Self {
71            constructors: RwLock::new(AlgorithmHashMap::new()),
72            plugins: RwLock::new(Vec::new()),
73            metadata: Self::create_metadata_map(),
74        }
75    }
76
77    /// Create the metadata map for all known algorithms
78    fn create_metadata_map() -> BTreeMap<Algorithm, &'static AeadMetadata> {
79        let mut metadata = BTreeMap::new();
80        let known_algorithms = [
81            Algorithm::Saturnin,
82            Algorithm::Shake256Aead,
83            Algorithm::DuplexSpongeAead,
84            Algorithm::TweakAead,
85            Algorithm::RomulusN,
86            Algorithm::RomulusM,
87        ];
88
89        for algorithm in known_algorithms {
90            if let Some(algorithm_metadata) = crate::metadata::get_metadata(algorithm) {
91                metadata.insert(algorithm, algorithm_metadata);
92            }
93        }
94
95        metadata
96    }
97
98    /// Register an algorithm constructor
99    pub fn register_algorithm<F>(&self, algorithm: Algorithm, constructor: F) -> Result<()>
100    where
101        F: Fn() -> Result<Box<dyn AeadWithMetadata>> + Send + Sync + 'static,
102    {
103        // Validate algorithm category
104        if algorithm.category() != AlgorithmCategory::Aead {
105            return Err(Error::InvalidAlgorithm {
106                algorithm: "Algorithm is not an AEAD algorithm",
107            });
108        }
109
110        #[cfg(feature = "std")]
111        {
112            let mut constructors = self.constructors.write().map_err(|_| Error::InvalidState {
113                operation: "register_algorithm".to_string(),
114                reason: "Failed to acquire write lock".to_string(),
115            })?;
116            constructors.insert(algorithm, Box::new(constructor));
117        }
118        #[cfg(not(feature = "std"))]
119        {
120            let mut constructors = self.constructors.write();
121            constructors.insert(algorithm, Box::new(constructor));
122        }
123        Ok(())
124    }
125
126    /// Register a plugin
127    pub fn register_plugin(&self, plugin: Box<dyn AeadPlugin>) -> Result<()> {
128        #[cfg(feature = "std")]
129        {
130            let mut plugins = self.plugins.write().map_err(|_| Error::InvalidState {
131                operation: "register_plugin".to_string(),
132                reason: "Failed to acquire write lock".to_string(),
133            })?;
134            plugins.push(plugin);
135        }
136        #[cfg(not(feature = "std"))]
137        {
138            let mut plugins = self.plugins.write();
139            plugins.push(plugin);
140        }
141        Ok(())
142    }
143
144    /// Create an AEAD instance for the specified algorithm
145    pub fn create_aead(&self, algorithm: Algorithm) -> Result<Box<dyn AeadWithMetadata>> {
146        // First try direct constructors
147        #[cfg(feature = "std")]
148        {
149            let constructors = self.constructors.read().map_err(|_| Error::InvalidState {
150                operation: "create_aead".to_string(),
151                reason: "Failed to acquire read lock".to_string(),
152            })?;
153            if let Some(constructor) = constructors.get(&algorithm) {
154                return constructor();
155            }
156        }
157        #[cfg(not(feature = "std"))]
158        {
159            let constructors = self.constructors.read();
160            if let Some(constructor) = constructors.get(&algorithm) {
161                return constructor();
162            }
163        }
164
165        // Then try plugins
166        #[cfg(feature = "std")]
167        {
168            let plugins = self.plugins.read().map_err(|_| Error::InvalidState {
169                operation: "create_aead".to_string(),
170                reason: "Failed to acquire read lock".to_string(),
171            })?;
172            for plugin in plugins.iter() {
173                if plugin.algorithm() == algorithm {
174                    return plugin.create();
175                }
176            }
177        }
178        #[cfg(not(feature = "std"))]
179        {
180            let plugins = self.plugins.read();
181            for plugin in plugins.iter() {
182                if plugin.algorithm() == algorithm {
183                    return plugin.create();
184                }
185            }
186        }
187
188        Err(Error::UnsupportedAlgorithm {
189            algorithm: "Algorithm not registered".to_string(),
190        })
191    }
192
193    /// Get available algorithms
194    pub fn available_algorithms(&self) -> Vec<Algorithm> {
195        let mut algorithms = Vec::new();
196
197        // Add algorithms from constructors
198        #[cfg(feature = "std")]
199        {
200            if let Ok(constructors) = self.constructors.read() {
201                algorithms.extend(constructors.keys().copied());
202            }
203        }
204        #[cfg(not(feature = "std"))]
205        {
206            let constructors = self.constructors.read();
207            algorithms.extend(constructors.keys().copied());
208        }
209
210        // Add algorithms from plugins
211        #[cfg(feature = "std")]
212        {
213            if let Ok(plugins) = self.plugins.read() {
214                for plugin in plugins.iter() {
215                    let algorithm = plugin.algorithm();
216                    if !algorithms.contains(&algorithm) {
217                        algorithms.push(algorithm);
218                    }
219                }
220            }
221        }
222        #[cfg(not(feature = "std"))]
223        {
224            let plugins = self.plugins.read();
225            for plugin in plugins.iter() {
226                let algorithm = plugin.algorithm();
227                if !algorithms.contains(&algorithm) {
228                    algorithms.push(algorithm);
229                }
230            }
231        }
232
233        algorithms.sort();
234        algorithms
235    }
236
237    /// Check if an algorithm is available
238    pub fn is_available(&self, algorithm: Algorithm) -> bool {
239        // Check constructors
240        #[cfg(feature = "std")]
241        {
242            if let Ok(constructors) = self.constructors.read() &&
243                constructors.contains_key(&algorithm)
244            {
245                return true;
246            }
247        }
248        #[cfg(not(feature = "std"))]
249        {
250            let constructors = self.constructors.read();
251            if constructors.contains_key(&algorithm) {
252                return true;
253            }
254        }
255
256        // Check plugins
257        #[cfg(feature = "std")]
258        {
259            if let Ok(plugins) = self.plugins.read() {
260                for plugin in plugins.iter() {
261                    if plugin.algorithm() == algorithm {
262                        return true;
263                    }
264                }
265            }
266        }
267        #[cfg(not(feature = "std"))]
268        {
269            let plugins = self.plugins.read();
270            for plugin in plugins.iter() {
271                if plugin.algorithm() == algorithm {
272                    return true;
273                }
274            }
275        }
276
277        false
278    }
279
280    /// Get algorithm metadata
281    pub fn get_metadata(&self, algorithm: Algorithm) -> Option<&'static AeadMetadata> {
282        self.metadata.get(&algorithm).copied()
283    }
284
285    /// Get all registered algorithms with their metadata
286    pub fn get_all_metadata(&self) -> Vec<&'static AeadMetadata> {
287        let available = self.available_algorithms();
288        available
289            .iter()
290            .filter_map(|&algorithm| self.get_metadata(algorithm))
291            .collect()
292    }
293}
294
295impl Default for AeadRegistry {
296    fn default() -> Self {
297        Self::new()
298    }
299}
300
#[cfg(test)]
mod tests {
    use lib_q_core::{
        Aead,
        AeadKey,
        Nonce,
    };

    use super::*;

    // Test-only stub: Layer A + metadata only (no `AeadDecryptSemantic`); mirrors `plugin` tests.
    struct MockAead {
        algorithm: Algorithm,
    }

    impl Aead for MockAead {
        fn encrypt(
            &self,
            _key: &AeadKey,
            _nonce: &Nonce,
            _plaintext: &[u8],
            _associated_data: Option<&[u8]>,
        ) -> Result<Vec<u8>> {
            Ok(alloc::vec![1, 2, 3, 4])
        }

        fn decrypt(
            &self,
            _key: &AeadKey,
            _nonce: &Nonce,
            _ciphertext: &[u8],
            _associated_data: Option<&[u8]>,
        ) -> Result<Vec<u8>> {
            Ok(alloc::vec![5, 6, 7, 8])
        }
    }

    impl AeadWithMetadata for MockAead {
        fn metadata(&self) -> &'static AeadMetadata {
            crate::metadata::get_metadata(self.algorithm).expect("Metadata not found")
        }

        fn supports_semantic_decrypt(&self) -> bool {
            false
        }
    }

    /// Constructor closure that builds a `MockAead` tagged with `algorithm`.
    fn mock_constructor(
        algorithm: Algorithm,
    ) -> impl Fn() -> Result<Box<dyn AeadWithMetadata>> + Send + Sync + 'static {
        move || Ok(Box::new(MockAead { algorithm }) as Box<dyn AeadWithMetadata>)
    }

    #[test]
    fn test_registry_creation() {
        let registry = AeadRegistry::new();
        assert!(registry.available_algorithms().is_empty());
        assert!(!registry.is_available(Algorithm::Saturnin));
        assert!(registry.get_all_metadata().is_empty());
    }

    #[test]
    fn test_algorithm_registration() {
        let registry = AeadRegistry::new();

        let result =
            registry.register_algorithm(Algorithm::Saturnin, mock_constructor(Algorithm::Saturnin));

        assert!(result.is_ok());
        assert!(registry.is_available(Algorithm::Saturnin));
        assert!(
            registry
                .available_algorithms()
                .contains(&Algorithm::Saturnin)
        );
    }

    #[test]
    fn test_algorithm_creation() {
        let registry = AeadRegistry::new();

        registry
            .register_algorithm(Algorithm::Saturnin, mock_constructor(Algorithm::Saturnin))
            .unwrap();

        let aead = registry
            .create_aead(Algorithm::Saturnin)
            .expect("registered algorithm should be constructible");
        assert_eq!(aead.metadata().algorithm, Algorithm::Saturnin);
    }

    #[test]
    fn test_reregistration_replaces_constructor() {
        let registry = AeadRegistry::new();

        for _ in 0..2 {
            registry
                .register_algorithm(Algorithm::Saturnin, mock_constructor(Algorithm::Saturnin))
                .unwrap();
        }

        assert_eq!(
            registry.available_algorithms(),
            alloc::vec![Algorithm::Saturnin]
        );
        assert!(registry.create_aead(Algorithm::Saturnin).is_ok());
    }

    #[test]
    fn test_available_algorithms_sorted_and_unique() {
        let registry = AeadRegistry::new();

        for algorithm in [
            Algorithm::RomulusM,
            Algorithm::Saturnin,
            Algorithm::RomulusN,
            Algorithm::Saturnin,
        ] {
            registry
                .register_algorithm(algorithm, mock_constructor(algorithm))
                .unwrap();
        }

        let algorithms = registry.available_algorithms();
        assert_eq!(algorithms.len(), 3);
        assert!(algorithms.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn test_get_all_metadata_follows_registration() {
        let registry = AeadRegistry::new();

        registry
            .register_algorithm(Algorithm::Saturnin, mock_constructor(Algorithm::Saturnin))
            .unwrap();

        let all = registry.get_all_metadata();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].algorithm, Algorithm::Saturnin);
    }

    #[test]
    fn test_invalid_algorithm_registration() {
        let registry = AeadRegistry::new();

        let result =
            registry.register_algorithm(Algorithm::MlKem512, mock_constructor(Algorithm::MlKem512));

        match result {
            Err(Error::InvalidAlgorithm { algorithm }) => {
                assert!(algorithm.contains("not an AEAD algorithm"));
            },
            other => panic!("Expected InvalidAlgorithm error, got {other:?}"),
        }
        assert!(!registry.is_available(Algorithm::MlKem512));
    }

    #[test]
    fn test_metadata_retrieval() {
        let registry = AeadRegistry::new();

        let meta = registry
            .get_metadata(Algorithm::Saturnin)
            .expect("Saturnin metadata should be built in");

        assert_eq!(meta.algorithm, Algorithm::Saturnin);
        assert_eq!(meta.name, "Saturnin");
        assert!(meta.key_size > 0);
        assert!(meta.nonce_size > 0);
        assert!(meta.tag_size > 0);
    }

    #[test]
    fn test_unsupported_algorithm() {
        let registry = AeadRegistry::new();

        match registry.create_aead(Algorithm::Shake256Aead) {
            Err(Error::UnsupportedAlgorithm { algorithm }) => {
                assert!(algorithm.contains("not registered"));
            },
            Err(other) => panic!("Expected UnsupportedAlgorithm error, got {other:?}"),
            Ok(_) => panic!("Expected UnsupportedAlgorithm error, got an instance"),
        }
    }
}