1use alloc::boxed::Box;
7use alloc::collections::BTreeMap;
8#[cfg(feature = "alloc")]
9use alloc::string::ToString;
10use alloc::vec::Vec;
11#[cfg(feature = "std")]
12use core::hash::Hasher;
13
/// Minimal unkeyed polynomial (base-31) byte hasher.
///
/// NOTE(review): nothing in this file section uses it, which is why
/// `dead_code` is allowed below. Because it has no seed/key, it is presumably
/// unsuitable for hashing attacker-controlled keys — confirm before reuse.
#[cfg(feature = "std")]
#[derive(Default)]
#[allow(dead_code)] // currently unused here; kept deliberately — TODO confirm it is still needed
struct AlgorithmHasher {
    /// Running hash value; starts at 0 via `Default`.
    state: u64,
}

#[cfg(feature = "std")]
impl Hasher for AlgorithmHasher {
    /// Folds every byte into the state as `state = state * 31 + byte`,
    /// using wrapping arithmetic so overflow never panics.
    fn write(&mut self, bytes: &[u8]) {
        self.state = bytes.iter().fold(self.state, |acc, &b| {
            acc.wrapping_mul(31).wrapping_add(u64::from(b))
        });
    }

    /// Returns the accumulated state unchanged.
    fn finish(&self) -> u64 {
        self.state
    }
}
36
/// Map from an AEAD [`Algorithm`] to its registered constructor.
///
/// Despite the name, this is a `BTreeMap`, not a hash map. Keys therefore
/// iterate in `Algorithm`'s `Ord` order, and inserting an existing key
/// replaces the previous constructor.
type AlgorithmHashMap = BTreeMap<Algorithm, AeadConstructor>;
41#[cfg(feature = "std")]
42use std::sync::RwLock;
43
44use lib_q_core::{
45 Algorithm,
46 AlgorithmCategory,
47 Error,
48 Result,
49};
50#[cfg(not(feature = "std"))]
51use spin::RwLock;
52
53use crate::AeadWithMetadata;
54use crate::metadata::AeadMetadata;
55use crate::plugin::AeadPlugin;
56
/// Boxed factory closure stored by [`AeadRegistry`].
///
/// Each call yields a boxed AEAD instance (with metadata) or an error. The
/// `Send + Sync` bounds presumably exist so the registry can be shared across
/// threads behind its `RwLock` — confirm against callers.
pub type AeadConstructor = Box<dyn Fn() -> Result<Box<dyn AeadWithMetadata>> + Send + Sync>;
59
/// Registry mapping AEAD [`Algorithm`]s to ways of constructing them.
///
/// Registration goes through `&self` because both mutable tables sit behind
/// an `RwLock` (`std::sync` with the `std` feature, `spin` otherwise).
/// `create_aead` consults explicitly registered constructors first, then
/// plugins. Static metadata for the built-in algorithms is kept separately and
/// does not depend on what has been registered.
pub struct AeadRegistry {
    /// Explicitly registered constructors, at most one per algorithm.
    constructors: RwLock<AlgorithmHashMap>,
    /// Registered plugins, searched in registration order.
    plugins: RwLock<Vec<Box<dyn AeadPlugin>>>,
    /// Metadata lookup built once in `new`. It is not lock-protected because
    /// no method in this impl mutates it afterwards.
    metadata: BTreeMap<Algorithm, &'static AeadMetadata>,
}
66
67impl AeadRegistry {
68 pub fn new() -> Self {
70 Self {
71 constructors: RwLock::new(AlgorithmHashMap::new()),
72 plugins: RwLock::new(Vec::new()),
73 metadata: Self::create_metadata_map(),
74 }
75 }
76
77 fn create_metadata_map() -> BTreeMap<Algorithm, &'static AeadMetadata> {
79 let mut metadata = BTreeMap::new();
80 let known_algorithms = [
81 Algorithm::Saturnin,
82 Algorithm::Shake256Aead,
83 Algorithm::DuplexSpongeAead,
84 Algorithm::TweakAead,
85 Algorithm::RomulusN,
86 Algorithm::RomulusM,
87 ];
88
89 for algorithm in known_algorithms {
90 if let Some(algorithm_metadata) = crate::metadata::get_metadata(algorithm) {
91 metadata.insert(algorithm, algorithm_metadata);
92 }
93 }
94
95 metadata
96 }
97
98 pub fn register_algorithm<F>(&self, algorithm: Algorithm, constructor: F) -> Result<()>
100 where
101 F: Fn() -> Result<Box<dyn AeadWithMetadata>> + Send + Sync + 'static,
102 {
103 if algorithm.category() != AlgorithmCategory::Aead {
105 return Err(Error::InvalidAlgorithm {
106 algorithm: "Algorithm is not an AEAD algorithm",
107 });
108 }
109
110 #[cfg(feature = "std")]
111 {
112 let mut constructors = self.constructors.write().map_err(|_| Error::InvalidState {
113 operation: "register_algorithm".to_string(),
114 reason: "Failed to acquire write lock".to_string(),
115 })?;
116 constructors.insert(algorithm, Box::new(constructor));
117 }
118 #[cfg(not(feature = "std"))]
119 {
120 let mut constructors = self.constructors.write();
121 constructors.insert(algorithm, Box::new(constructor));
122 }
123 Ok(())
124 }
125
126 pub fn register_plugin(&self, plugin: Box<dyn AeadPlugin>) -> Result<()> {
128 #[cfg(feature = "std")]
129 {
130 let mut plugins = self.plugins.write().map_err(|_| Error::InvalidState {
131 operation: "register_plugin".to_string(),
132 reason: "Failed to acquire write lock".to_string(),
133 })?;
134 plugins.push(plugin);
135 }
136 #[cfg(not(feature = "std"))]
137 {
138 let mut plugins = self.plugins.write();
139 plugins.push(plugin);
140 }
141 Ok(())
142 }
143
144 pub fn create_aead(&self, algorithm: Algorithm) -> Result<Box<dyn AeadWithMetadata>> {
146 #[cfg(feature = "std")]
148 {
149 let constructors = self.constructors.read().map_err(|_| Error::InvalidState {
150 operation: "create_aead".to_string(),
151 reason: "Failed to acquire read lock".to_string(),
152 })?;
153 if let Some(constructor) = constructors.get(&algorithm) {
154 return constructor();
155 }
156 }
157 #[cfg(not(feature = "std"))]
158 {
159 let constructors = self.constructors.read();
160 if let Some(constructor) = constructors.get(&algorithm) {
161 return constructor();
162 }
163 }
164
165 #[cfg(feature = "std")]
167 {
168 let plugins = self.plugins.read().map_err(|_| Error::InvalidState {
169 operation: "create_aead".to_string(),
170 reason: "Failed to acquire read lock".to_string(),
171 })?;
172 for plugin in plugins.iter() {
173 if plugin.algorithm() == algorithm {
174 return plugin.create();
175 }
176 }
177 }
178 #[cfg(not(feature = "std"))]
179 {
180 let plugins = self.plugins.read();
181 for plugin in plugins.iter() {
182 if plugin.algorithm() == algorithm {
183 return plugin.create();
184 }
185 }
186 }
187
188 Err(Error::UnsupportedAlgorithm {
189 algorithm: "Algorithm not registered".to_string(),
190 })
191 }
192
193 pub fn available_algorithms(&self) -> Vec<Algorithm> {
195 let mut algorithms = Vec::new();
196
197 #[cfg(feature = "std")]
199 {
200 if let Ok(constructors) = self.constructors.read() {
201 algorithms.extend(constructors.keys().copied());
202 }
203 }
204 #[cfg(not(feature = "std"))]
205 {
206 let constructors = self.constructors.read();
207 algorithms.extend(constructors.keys().copied());
208 }
209
210 #[cfg(feature = "std")]
212 {
213 if let Ok(plugins) = self.plugins.read() {
214 for plugin in plugins.iter() {
215 let algorithm = plugin.algorithm();
216 if !algorithms.contains(&algorithm) {
217 algorithms.push(algorithm);
218 }
219 }
220 }
221 }
222 #[cfg(not(feature = "std"))]
223 {
224 let plugins = self.plugins.read();
225 for plugin in plugins.iter() {
226 let algorithm = plugin.algorithm();
227 if !algorithms.contains(&algorithm) {
228 algorithms.push(algorithm);
229 }
230 }
231 }
232
233 algorithms.sort();
234 algorithms
235 }
236
237 pub fn is_available(&self, algorithm: Algorithm) -> bool {
239 #[cfg(feature = "std")]
241 {
242 if let Ok(constructors) = self.constructors.read() &&
243 constructors.contains_key(&algorithm)
244 {
245 return true;
246 }
247 }
248 #[cfg(not(feature = "std"))]
249 {
250 let constructors = self.constructors.read();
251 if constructors.contains_key(&algorithm) {
252 return true;
253 }
254 }
255
256 #[cfg(feature = "std")]
258 {
259 if let Ok(plugins) = self.plugins.read() {
260 for plugin in plugins.iter() {
261 if plugin.algorithm() == algorithm {
262 return true;
263 }
264 }
265 }
266 }
267 #[cfg(not(feature = "std"))]
268 {
269 let plugins = self.plugins.read();
270 for plugin in plugins.iter() {
271 if plugin.algorithm() == algorithm {
272 return true;
273 }
274 }
275 }
276
277 false
278 }
279
280 pub fn get_metadata(&self, algorithm: Algorithm) -> Option<&'static AeadMetadata> {
282 self.metadata.get(&algorithm).copied()
283 }
284
285 pub fn get_all_metadata(&self) -> Vec<&'static AeadMetadata> {
287 let available = self.available_algorithms();
288 available
289 .iter()
290 .filter_map(|&algorithm| self.get_metadata(algorithm))
291 .collect()
292 }
293}
294
295impl Default for AeadRegistry {
296 fn default() -> Self {
297 Self::new()
298 }
299}
300
#[cfg(test)]
mod tests {
    use lib_q_core::{
        Aead,
        AeadKey,
        Nonce,
    };

    use super::*;

    /// Test double: encrypt/decrypt return fixed bytes, and metadata comes from
    /// the crate's static table for `algorithm`.
    struct MockAead {
        algorithm: Algorithm,
    }

    impl Aead for MockAead {
        fn encrypt(
            &self,
            _key: &AeadKey,
            _nonce: &Nonce,
            _plaintext: &[u8],
            _associated_data: Option<&[u8]>,
        ) -> Result<Vec<u8>> {
            Ok(alloc::vec![1, 2, 3, 4])
        }

        fn decrypt(
            &self,
            _key: &AeadKey,
            _nonce: &Nonce,
            _ciphertext: &[u8],
            _associated_data: Option<&[u8]>,
        ) -> Result<Vec<u8>> {
            Ok(alloc::vec![5, 6, 7, 8])
        }
    }

    impl AeadWithMetadata for MockAead {
        fn metadata(&self) -> &'static AeadMetadata {
            crate::metadata::get_metadata(self.algorithm).expect("Metadata not found")
        }

        fn supports_semantic_decrypt(&self) -> bool {
            false
        }
    }

    /// Constructor accepted by `register_algorithm` that builds a Saturnin mock.
    fn saturnin_constructor() -> Result<Box<dyn AeadWithMetadata>> {
        Ok(Box::new(MockAead {
            algorithm: Algorithm::Saturnin,
        }) as Box<dyn AeadWithMetadata>)
    }

    #[test]
    fn test_registry_creation() {
        let registry = AeadRegistry::new();
        assert!(registry.available_algorithms().is_empty());
    }

    #[test]
    fn test_algorithm_registration() {
        let registry = AeadRegistry::new();

        let result = registry.register_algorithm(Algorithm::Saturnin, saturnin_constructor);

        assert!(result.is_ok());
        assert!(registry.is_available(Algorithm::Saturnin));
        assert!(
            registry
                .available_algorithms()
                .contains(&Algorithm::Saturnin)
        );
    }

    #[test]
    fn test_algorithm_creation() {
        let registry = AeadRegistry::new();
        registry
            .register_algorithm(Algorithm::Saturnin, saturnin_constructor)
            .unwrap();

        let aead = registry
            .create_aead(Algorithm::Saturnin)
            .expect("Saturnin was registered");
        // The instance must come from the registered constructor, not just "some Ok".
        assert_eq!(aead.metadata().algorithm, Algorithm::Saturnin);
    }

    #[test]
    fn test_invalid_algorithm_registration() {
        let registry = AeadRegistry::new();

        let result = registry.register_algorithm(Algorithm::MlKem512, || {
            Ok(Box::new(MockAead {
                algorithm: Algorithm::MlKem512,
            }) as Box<dyn AeadWithMetadata>)
        });

        let Err(Error::InvalidAlgorithm { algorithm }) = result else {
            panic!("Expected InvalidAlgorithm error");
        };
        assert!(algorithm.contains("not an AEAD algorithm"));
    }

    #[test]
    fn test_metadata_retrieval() {
        let registry = AeadRegistry::new();

        let Some(meta) = registry.get_metadata(Algorithm::Saturnin) else {
            panic!("Expected Saturnin metadata");
        };
        assert_eq!(meta.algorithm, Algorithm::Saturnin);
        assert_eq!(meta.name, "Saturnin");
        assert!(meta.key_size > 0);
        assert!(meta.nonce_size > 0);
        assert!(meta.tag_size > 0);
    }

    #[test]
    fn test_unsupported_algorithm() {
        let registry = AeadRegistry::new();

        let result = registry.create_aead(Algorithm::Shake256Aead);

        let Err(Error::UnsupportedAlgorithm { algorithm }) = result else {
            panic!("Expected UnsupportedAlgorithm error");
        };
        assert!(algorithm.contains("not registered"));
    }

    #[test]
    fn test_is_available_false_when_unregistered() {
        let registry = AeadRegistry::new();
        assert!(!registry.is_available(Algorithm::Saturnin));
    }

    #[test]
    fn test_reregistration_is_listed_once() {
        let registry = AeadRegistry::new();
        registry
            .register_algorithm(Algorithm::Saturnin, saturnin_constructor)
            .unwrap();
        registry
            .register_algorithm(Algorithm::Saturnin, saturnin_constructor)
            .unwrap();

        assert_eq!(
            registry.available_algorithms(),
            alloc::vec![Algorithm::Saturnin]
        );
    }

    #[test]
    fn test_all_metadata_follows_registrations() {
        let registry = AeadRegistry::new();
        assert!(registry.get_all_metadata().is_empty());

        registry
            .register_algorithm(Algorithm::Saturnin, saturnin_constructor)
            .unwrap();

        let all = registry.get_all_metadata();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].algorithm, Algorithm::Saturnin);
    }
}