1use std::collections::{HashMap, HashSet};
2use std::sync::{
3 atomic::{AtomicU64, Ordering},
4 RwLock,
5};
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use rand::{rngs::OsRng, RngCore};
9
10use crate::error::A1StorageError;
11
12pub fn fresh_nonce() -> [u8; 16] {
13 let mut nonce = [0u8; 16];
14 OsRng.fill_bytes(&mut nonce);
15 nonce
16}
17
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time earlier than the epoch the function
/// does not fail; it returns `0` instead.
fn unix_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        // Clock set before 1970: equivalent to a zero-length duration.
        Err(_) => 0,
    }
}
24
25pub trait RevocationStore: Send + Sync {
28 fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, A1StorageError>;
29 fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), A1StorageError>;
30
31 fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), A1StorageError> {
32 for fp in fingerprints {
33 self.revoke(fp)?;
34 }
35 Ok(())
36 }
37
38 fn health_check(&self) -> Result<(), A1StorageError> {
39 Ok(())
40 }
41}
42
43pub trait NonceStore: Send + Sync {
46 fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError>;
47
48 fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, A1StorageError> {
49 for nonce in nonces {
50 if self.is_consumed(nonce)? {
51 return Ok(false);
52 }
53 }
54 for nonce in nonces {
55 self.mark_consumed(nonce)?;
56 }
57 Ok(true)
58 }
59
60 fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError>;
61 fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), A1StorageError>;
62
63 fn health_check(&self) -> Result<(), A1StorageError> {
64 Ok(())
65 }
66}
67
68pub trait RateLimitStore: Send + Sync {
79 fn check_and_record(
80 &self,
81 key: &[u8],
82 max_per_window: u32,
83 window_secs: u64,
84 ) -> Result<bool, A1StorageError>;
85
86 fn health_check(&self) -> Result<(), A1StorageError> {
87 Ok(())
88 }
89}
90
/// In-memory revocation set: a small bloom filter in front of sharded hash sets.
pub struct MemoryRevocationStore {
    /// Bloom filter made of 64 words of 64 bits each. `is_revoked` reads it
    /// before touching any shard lock.
    bloom: Vec<AtomicU64>,
    /// Exact membership, split into 64 independently locked sets; the shard is
    /// chosen from the last eight bytes of the fingerprint.
    shards: Vec<RwLock<HashSet<[u8; 32]>>>,
}

impl MemoryRevocationStore {
    /// Creates an empty store with 64 zeroed bloom words and 64 empty shards.
    #[must_use]
    pub fn new() -> Self {
        let mut bloom = Vec::with_capacity(64);
        bloom.resize_with(64, || AtomicU64::new(0));

        let mut shards = Vec::with_capacity(64);
        shards.resize_with(64, || RwLock::new(HashSet::new()));

        Self { bloom, shards }
    }

    /// Derives the three bloom positions for `fp`.
    ///
    /// Each entry is `(word index, single-bit mask)`. The three hashes are
    /// taken from bytes 0..8, 8..16 and 16..24 of the fingerprint, each mixed
    /// with the same seed in a different way.
    #[inline(always)]
    fn bloom_indices(fp: &[u8; 32]) -> [(usize, u64); 3] {
        const IDENTITY_SEED: u64 = 0x6479_6F6C_6FA1_2800;

        // Little-endian u64 made of the eight bytes starting at `offset`.
        let lane = |offset: usize| -> u64 {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(&fp[offset..offset + 8]);
            u64::from_le_bytes(raw)
        };
        // Word comes from the low bits of the hash, the bit from a rotation
        // of it, so the two are not taken from the same bit range.
        let slot = |hash: u64, rotation: u32| -> (usize, u64) {
            ((hash % 64) as usize, 1u64 << (hash.rotate_right(rotation) % 64))
        };

        let first = lane(0).wrapping_mul(IDENTITY_SEED);
        let second = lane(8).wrapping_add(IDENTITY_SEED.rotate_left(17));
        let third = lane(16) ^ IDENTITY_SEED;

        [slot(first, 13), slot(second, 27), slot(third, 41)]
    }

    /// Picks the shard for `fp` from its last eight bytes (little-endian, mod 64).
    #[inline(always)]
    fn shard_index(fp: &[u8; 32]) -> usize {
        let mut tail = [0u8; 8];
        tail.copy_from_slice(&fp[24..]);
        (u64::from_le_bytes(tail) % 64) as usize
    }
}

impl Default for MemoryRevocationStore {
    /// Same as [`MemoryRevocationStore::new`].
    fn default() -> Self {
        Self::new()
    }
}
131
132impl RevocationStore for MemoryRevocationStore {
133 fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, A1StorageError> {
134 for (word, bit) in Self::bloom_indices(fingerprint) {
135 if (self.bloom[word].load(Ordering::Relaxed) & bit) == 0 {
136 return Ok(false);
137 }
138 }
139 let shard = self.shards[Self::shard_index(fingerprint)]
140 .read()
141 .map_err(|_| A1StorageError::permanent("revocation shard lock poisoned"))?;
142 Ok(shard.contains(fingerprint))
143 }
144
145 fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), A1StorageError> {
146 for (word, bit) in Self::bloom_indices(fingerprint) {
147 self.bloom[word].fetch_or(bit, Ordering::SeqCst);
148 }
149 self.shards[Self::shard_index(fingerprint)]
150 .write()
151 .map_err(|_| A1StorageError::permanent("revocation shard lock poisoned"))?
152 .insert(*fingerprint);
153 Ok(())
154 }
155
156 fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), A1StorageError> {
157 for fp in fingerprints {
158 for (word, bit) in Self::bloom_indices(fp) {
159 self.bloom[word].fetch_or(bit, Ordering::SeqCst);
160 }
161 }
162 for fp in fingerprints {
163 self.shards[Self::shard_index(fp)]
164 .write()
165 .map_err(|_| A1StorageError::permanent("revocation shard lock poisoned"))?
166 .insert(*fp);
167 }
168 Ok(())
169 }
170}
171
/// In-memory nonce store: a one-hash bloom filter in front of a locked map.
pub struct MemoryNonceStore {
    /// Bloom filter of `BLOOM_WORDS` 64-bit words; one bit is set per nonce.
    bloom: Vec<AtomicU64>,
    /// Consumed nonces mapped to the Unix second up to which they count as
    /// consumed (`u64::MAX` when no TTL is configured).
    store: RwLock<HashMap<[u8; 16], u64>>,
    /// Optional lifetime of a consumed nonce, in seconds.
    ttl_secs: Option<u64>,
}

impl MemoryNonceStore {
    /// Number of 64-bit words in the bloom filter.
    const BLOOM_WORDS: usize = 1024;

    /// Creates an empty store without a TTL.
    #[must_use]
    pub fn new() -> Self {
        let mut bloom = Vec::with_capacity(Self::BLOOM_WORDS);
        bloom.resize_with(Self::BLOOM_WORDS, || AtomicU64::new(0));

        Self {
            bloom,
            store: RwLock::new(HashMap::new()),
            ttl_secs: None,
        }
    }

    /// Returns the store with its TTL set to `ttl_secs` seconds.
    pub fn with_ttl_secs(self, ttl_secs: u64) -> Self {
        Self {
            ttl_secs: Some(ttl_secs),
            ..self
        }
    }

    /// Maps `nonce` to its bloom position as `(word index, single-bit mask)`.
    ///
    /// Both halves of the nonce are read as little-endian `u64`s and mixed
    /// into one hash; the word comes from the hash modulo `BLOOM_WORDS`, the
    /// bit from a rotation of the same hash.
    #[inline(always)]
    fn indices(nonce: &[u8; 16]) -> (usize, u64) {
        const PROVENANCE_SEED: u64 = 0x6479_6F6C_6FA1_2800;

        let mut low = [0u8; 8];
        let mut high = [0u8; 8];
        low.copy_from_slice(&nonce[..8]);
        high.copy_from_slice(&nonce[8..]);
        let e1 = u64::from_le_bytes(low);
        let e2 = u64::from_le_bytes(high);

        let mixed = e1
            .wrapping_mul(PROVENANCE_SEED)
            .wrapping_add(e2.rotate_left(23));
        let h = mixed ^ 0x9E3779B185EBCA87;

        let word = (h as usize) % Self::BLOOM_WORDS;
        let bit = 1u64 << (h.rotate_right(11) % 64);
        (word, bit)
    }
}

impl Default for MemoryNonceStore {
    /// Same as [`MemoryNonceStore::new`].
    fn default() -> Self {
        Self::new()
    }
}
218
219impl NonceStore for MemoryNonceStore {
220 fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError> {
221 let (word, bit) = Self::indices(nonce);
222 let mut guard = self
223 .store
224 .write()
225 .map_err(|_| A1StorageError::permanent("nonce store lock poisoned"))?;
226
227 let now = unix_now();
228
229 if (self.bloom[word].load(Ordering::Acquire) & bit) != 0 {
230 if let Some(&exp) = guard.get(nonce) {
231 if exp >= now {
232 return Ok(false);
233 }
234 }
235 }
236
237 if let Some(ttl) = self.ttl_secs {
238 guard.retain(|_, exp| *exp >= now);
239 guard.insert(*nonce, now.saturating_add(ttl));
240 } else {
241 guard.insert(*nonce, u64::MAX);
242 }
243
244 self.bloom[word].fetch_or(bit, Ordering::Release);
245 Ok(true)
246 }
247
248 fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, A1StorageError> {
249 let mut guard = self
250 .store
251 .write()
252 .map_err(|_| A1StorageError::permanent("nonce store lock poisoned"))?;
253
254 let now = unix_now();
255
256 for nonce in nonces {
257 let (word, bit) = Self::indices(nonce);
258 if (self.bloom[word].load(Ordering::Acquire) & bit) != 0 {
259 if let Some(&exp) = guard.get(nonce) {
260 if exp >= now {
261 return Ok(false);
262 }
263 }
264 }
265 }
266
267 if let Some(ttl) = self.ttl_secs {
268 guard.retain(|_, exp| *exp >= now);
269 for nonce in nonces {
270 let (word, bit) = Self::indices(nonce);
271 guard.insert(*nonce, now.saturating_add(ttl));
272 self.bloom[word].fetch_or(bit, Ordering::Release);
273 }
274 } else {
275 for nonce in nonces {
276 let (word, bit) = Self::indices(nonce);
277 guard.insert(*nonce, u64::MAX);
278 self.bloom[word].fetch_or(bit, Ordering::Release);
279 }
280 }
281
282 Ok(true)
283 }
284
285 fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError> {
286 let (word, bit) = Self::indices(nonce);
287 if (self.bloom[word].load(Ordering::Acquire) & bit) == 0 {
288 return Ok(false);
289 }
290 let guard = self
291 .store
292 .read()
293 .map_err(|_| A1StorageError::permanent("nonce store lock poisoned"))?;
294 let now = unix_now();
295 if let Some(&exp) = guard.get(nonce) {
296 Ok(exp >= now)
297 } else {
298 Ok(false)
299 }
300 }
301
302 fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), A1StorageError> {
303 self.try_consume(nonce).map(|_| ())
304 }
305}
306
307struct RateBucket {
310 window_start_secs: u64,
311 count: u32,
312}
313
314pub struct MemoryRateLimitStore {
320 buckets: RwLock<HashMap<Vec<u8>, RateBucket>>,
321}
322
323impl MemoryRateLimitStore {
324 #[must_use]
325 pub fn new() -> Self {
326 Self {
327 buckets: RwLock::new(HashMap::new()),
328 }
329 }
330}
331
332impl Default for MemoryRateLimitStore {
333 fn default() -> Self {
334 Self::new()
335 }
336}
337
338impl RateLimitStore for MemoryRateLimitStore {
339 fn check_and_record(
340 &self,
341 key: &[u8],
342 max_per_window: u32,
343 window_secs: u64,
344 ) -> Result<bool, A1StorageError> {
345 let now = unix_now();
346 let window_start = (now / window_secs.max(1)) * window_secs.max(1);
347
348 let mut buckets = self
349 .buckets
350 .write()
351 .map_err(|_| A1StorageError::permanent("rate limit store lock poisoned"))?;
352
353 buckets.retain(|_, v| now.saturating_sub(v.window_start_secs) < window_secs.max(1) * 2);
354
355 let bucket = buckets.entry(key.to_vec()).or_insert(RateBucket {
356 window_start_secs: window_start,
357 count: 0,
358 });
359
360 if bucket.window_start_secs != window_start {
361 bucket.window_start_secs = window_start;
362 bucket.count = 0;
363 }
364
365 if bucket.count >= max_per_window {
366 return Ok(false);
367 }
368
369 bucket.count += 1;
370 Ok(true)
371 }
372}
373
/// Async counterparts of the storage traits, plus adapters that run the
/// synchronous stores on Tokio's blocking thread pool.
#[cfg(feature = "async")]
pub mod r#async {
    use crate::error::A1StorageError;
    use async_trait::async_trait;
    use std::sync::Arc;

    /// Runs `job` via `tokio::task::spawn_blocking` and returns its result.
    ///
    /// A failure to join the blocking task is reported as a transient
    /// [`A1StorageError`] carrying the join error's text.
    async fn run_blocking<T, F>(job: F) -> Result<T, A1StorageError>
    where
        F: FnOnce() -> Result<T, A1StorageError> + Send + 'static,
        T: Send + 'static,
    {
        tokio::task::spawn_blocking(job)
            .await
            .map_err(|e| A1StorageError::transient(e.to_string()))?
    }

    /// Async version of [`super::RevocationStore`].
    #[async_trait]
    pub trait AsyncRevocationStore: Send + Sync {
        /// Reports whether `fingerprint` is recorded as revoked.
        async fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, A1StorageError>;

        /// Records `fingerprint` as revoked.
        async fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), A1StorageError>;

        /// Revokes each fingerprint in order; the default awaits
        /// [`revoke`](Self::revoke) per entry and stops at the first error.
        async fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), A1StorageError> {
            for fingerprint in fingerprints.iter() {
                self.revoke(fingerprint).await?;
            }
            Ok(())
        }

        /// Checks that the backend is usable. The default always succeeds.
        async fn health_check(&self) -> Result<(), A1StorageError> {
            Ok(())
        }
    }

    /// Async version of [`super::NonceStore`].
    #[async_trait]
    pub trait AsyncNonceStore: Send + Sync {
        /// Attempts to consume `nonce` in a single step.
        async fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError>;

        /// Consumes a whole batch, or nothing if any member is already consumed.
        ///
        /// The default first awaits [`is_consumed`](Self::is_consumed) for
        /// every nonce (returning `Ok(false)` on the first hit) and then
        /// awaits [`mark_consumed`](Self::mark_consumed) for every nonce.
        /// Nothing in this default holds a lock across the two passes.
        async fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, A1StorageError> {
            for candidate in nonces.iter() {
                let already_used = self.is_consumed(candidate).await?;
                if already_used {
                    return Ok(false);
                }
            }
            for candidate in nonces.iter() {
                self.mark_consumed(candidate).await?;
            }
            Ok(true)
        }

        /// Reports whether `nonce` is currently recorded as consumed.
        async fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError>;

        /// Records `nonce` as consumed.
        async fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), A1StorageError>;

        /// Checks that the backend is usable. The default always succeeds.
        async fn health_check(&self) -> Result<(), A1StorageError> {
            Ok(())
        }
    }

    /// Async version of [`super::RateLimitStore`].
    #[async_trait]
    pub trait AsyncRateLimitStore: Send + Sync {
        /// Checks the limit for `key` and records the call if it is admitted.
        async fn check_and_record(
            &self,
            key: &[u8],
            max_per_window: u32,
            window_secs: u64,
        ) -> Result<bool, A1StorageError>;

        /// Checks that the backend is usable. The default always succeeds.
        async fn health_check(&self) -> Result<(), A1StorageError> {
            Ok(())
        }
    }

    /// Exposes a synchronous [`super::RevocationStore`] through
    /// [`AsyncRevocationStore`] by running each call as a blocking task.
    pub struct SyncRevocationAdapter<S>(pub Arc<S>);

    #[async_trait]
    impl<S> AsyncRevocationStore for SyncRevocationAdapter<S>
    where
        S: super::RevocationStore + 'static,
    {
        async fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, A1StorageError> {
            // The fingerprint is copied so the blocking task owns its input.
            let (inner, fp) = (Arc::clone(&self.0), *fingerprint);
            run_blocking(move || inner.is_revoked(&fp)).await
        }

        async fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), A1StorageError> {
            let (inner, fp) = (Arc::clone(&self.0), *fingerprint);
            run_blocking(move || inner.revoke(&fp)).await
        }

        async fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), A1StorageError> {
            let inner = Arc::clone(&self.0);
            let owned: Vec<[u8; 32]> = fingerprints.to_vec();
            run_blocking(move || inner.revoke_batch(&owned)).await
        }

        async fn health_check(&self) -> Result<(), A1StorageError> {
            let inner = Arc::clone(&self.0);
            run_blocking(move || inner.health_check()).await
        }
    }

    /// Exposes a synchronous [`super::RateLimitStore`] through
    /// [`AsyncRateLimitStore`] by running each call as a blocking task.
    pub struct SyncRateLimitAdapter<S>(pub Arc<S>);

    #[async_trait]
    impl<S> AsyncRateLimitStore for SyncRateLimitAdapter<S>
    where
        S: super::RateLimitStore + 'static,
    {
        async fn check_and_record(
            &self,
            key: &[u8],
            max_per_window: u32,
            window_secs: u64,
        ) -> Result<bool, A1StorageError> {
            let inner = Arc::clone(&self.0);
            let owned_key = key.to_vec();
            run_blocking(move || inner.check_and_record(&owned_key, max_per_window, window_secs))
                .await
        }

        async fn health_check(&self) -> Result<(), A1StorageError> {
            let inner = Arc::clone(&self.0);
            run_blocking(move || inner.health_check()).await
        }
    }

    /// Exposes a synchronous [`super::NonceStore`] through
    /// [`AsyncNonceStore`] by running each call as a blocking task.
    pub struct SyncNonceAdapter<S>(pub Arc<S>);

    #[async_trait]
    impl<S> AsyncNonceStore for SyncNonceAdapter<S>
    where
        S: super::NonceStore + 'static,
    {
        async fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError> {
            let (inner, owned) = (Arc::clone(&self.0), *nonce);
            run_blocking(move || inner.try_consume(&owned)).await
        }

        async fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, A1StorageError> {
            let inner = Arc::clone(&self.0);
            let owned = nonces.to_vec();
            run_blocking(move || inner.try_consume_batch(&owned)).await
        }

        async fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError> {
            let (inner, owned) = (Arc::clone(&self.0), *nonce);
            run_blocking(move || inner.is_consumed(&owned)).await
        }

        async fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), A1StorageError> {
            let (inner, owned) = (Arc::clone(&self.0), *nonce);
            run_blocking(move || inner.mark_consumed(&owned)).await
        }

        async fn health_check(&self) -> Result<(), A1StorageError> {
            let inner = Arc::clone(&self.0);
            run_blocking(move || inner.health_check()).await
        }
    }
}
544
/// Unit tests for the in-memory stores.
#[cfg(test)]
mod tests {
    use super::*;

    /// A nonce can be consumed once; afterwards it reports as consumed.
    #[test]
    fn try_consume_is_atomic_and_idempotent() {
        let store = MemoryNonceStore::new();
        let nonce = fresh_nonce();

        let first_attempt = store.try_consume(&nonce).unwrap();
        let second_attempt = store.try_consume(&nonce).unwrap();

        assert!(first_attempt);
        assert!(!second_attempt);
        assert!(store.is_consumed(&nonce).unwrap());
    }

    /// Of 32 threads racing for the same nonce, exactly one wins.
    #[test]
    fn try_consume_concurrent_exactly_one_winner() {
        use std::{sync::Arc, thread};

        let store = Arc::new(MemoryNonceStore::new());
        let nonce = fresh_nonce();

        let mut workers = Vec::with_capacity(32);
        for _ in 0..32 {
            let shared = Arc::clone(&store);
            workers.push(thread::spawn(move || shared.try_consume(&nonce).unwrap()));
        }

        let wins = workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .filter(|&won| won)
            .count();
        assert_eq!(wins, 1);
    }

    /// Every fingerprint passed to `revoke_batch` is reported as revoked.
    #[test]
    fn revoke_batch_marks_all_fingerprints() {
        let store = MemoryRevocationStore::new();

        // Eight fingerprints that differ only in their first byte (0..8).
        let mut fps = vec![[0u8; 32]; 8];
        for (i, fp) in fps.iter_mut().enumerate() {
            fp[0] = i as u8;
        }

        store.revoke_batch(&fps).unwrap();
        for fp in fps.iter() {
            assert!(store.is_revoked(fp).unwrap());
        }
    }

    /// The default `health_check` succeeds for all three memory stores.
    #[test]
    fn health_check_returns_ok_for_memory_stores() {
        let revocations = MemoryRevocationStore::new();
        let nonces = MemoryNonceStore::new();
        let limits = MemoryRateLimitStore::new();

        assert!(revocations.health_check().is_ok());
        assert!(nonces.health_check().is_ok());
        assert!(limits.health_check().is_ok());
    }

    /// Five calls are admitted within the window, the sixth is rejected.
    #[test]
    fn rate_limit_enforces_window() {
        let store = MemoryRateLimitStore::new();
        let key = b"test-principal";

        (0..5).for_each(|_| {
            assert!(
                store.check_and_record(key, 5, 60).unwrap(),
                "should be allowed"
            );
        });

        let sixth = store.check_and_record(key, 5, 60).unwrap();
        assert!(!sixth, "should be blocked");
    }
}