Skip to main content

a1/
registry.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::{
3    atomic::{AtomicU64, Ordering},
4    RwLock,
5};
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use rand::{rngs::OsRng, RngCore};
9
10use crate::error::A1StorageError;
11
12pub fn fresh_nonce() -> [u8; 16] {
13    let mut nonce = [0u8; 16];
14    OsRng.fill_bytes(&mut nonce);
15    nonce
16}
17
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` rather than an error.
fn unix_now() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH);
    since_epoch.map_or(0, |elapsed| elapsed.as_secs())
}
24
25// ── RevocationStore ───────────────────────────────────────────────────────────
26
27pub trait RevocationStore: Send + Sync {
28    fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, A1StorageError>;
29    fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), A1StorageError>;
30
31    fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), A1StorageError> {
32        for fp in fingerprints {
33            self.revoke(fp)?;
34        }
35        Ok(())
36    }
37
38    fn health_check(&self) -> Result<(), A1StorageError> {
39        Ok(())
40    }
41}
42
43// ── NonceStore ────────────────────────────────────────────────────────────────
44
45pub trait NonceStore: Send + Sync {
46    fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError>;
47
48    fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, A1StorageError> {
49        for nonce in nonces {
50            if self.is_consumed(nonce)? {
51                return Ok(false);
52            }
53        }
54        for nonce in nonces {
55            self.mark_consumed(nonce)?;
56        }
57        Ok(true)
58    }
59
60    fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError>;
61    fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), A1StorageError>;
62
63    fn health_check(&self) -> Result<(), A1StorageError> {
64        Ok(())
65    }
66}
67
68// ── RateLimitStore ────────────────────────────────────────────────────────────
69
70/// Token-bucket rate limiter storage abstraction.
71///
72/// `check_and_record` must atomically verify that the key has not exceeded its
73/// budget, then record the attempt. Returns `Ok(true)` when the request is
74/// within the limit and `Ok(false)` when the limit has been reached.
75///
76/// Implement this trait over Redis (sliding-window counter with INCR + EXPIRE)
77/// or Postgres (per-key row with an atomic UPDATE) for distributed deployments.
78pub trait RateLimitStore: Send + Sync {
79    fn check_and_record(
80        &self,
81        key: &[u8],
82        max_per_window: u32,
83        window_secs: u64,
84    ) -> Result<bool, A1StorageError>;
85
86    fn health_check(&self) -> Result<(), A1StorageError> {
87        Ok(())
88    }
89}
90
91// ── MemoryRevocationStore ─────────────────────────────────────────────────────
92
93pub struct MemoryRevocationStore {
94    bloom: Vec<AtomicU64>,
95    shards: Vec<RwLock<HashSet<[u8; 32]>>>,
96}
97
98impl MemoryRevocationStore {
99    #[must_use]
100    pub fn new() -> Self {
101        let bloom = (0..64).map(|_| AtomicU64::new(0)).collect();
102        let shards = (0..64).map(|_| RwLock::new(HashSet::new())).collect();
103        Self { bloom, shards }
104    }
105
106    #[inline(always)]
107    fn bloom_indices(fp: &[u8; 32]) -> [(usize, u64); 3] {
108        const IDENTITY_SEED: u64 = 0x6479_6F6C_6FA1_2800;
109        let h1 = u64::from_le_bytes(fp[0..8].try_into().unwrap()).wrapping_mul(IDENTITY_SEED);
110        let h2 = u64::from_le_bytes(fp[8..16].try_into().unwrap())
111            .wrapping_add(IDENTITY_SEED.rotate_left(17));
112        let h3 = u64::from_le_bytes(fp[16..24].try_into().unwrap()) ^ IDENTITY_SEED;
113        [
114            ((h1 % 64) as usize, 1u64 << (h1.rotate_right(13) % 64)),
115            ((h2 % 64) as usize, 1u64 << (h2.rotate_right(27) % 64)),
116            ((h3 % 64) as usize, 1u64 << (h3.rotate_right(41) % 64)),
117        ]
118    }
119
120    #[inline(always)]
121    fn shard_index(fp: &[u8; 32]) -> usize {
122        (u64::from_le_bytes(fp[24..32].try_into().unwrap()) % 64) as usize
123    }
124}
125
126impl Default for MemoryRevocationStore {
127    fn default() -> Self {
128        Self::new()
129    }
130}
131
132impl RevocationStore for MemoryRevocationStore {
133    fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, A1StorageError> {
134        for (word, bit) in Self::bloom_indices(fingerprint) {
135            if (self.bloom[word].load(Ordering::Relaxed) & bit) == 0 {
136                return Ok(false);
137            }
138        }
139        let shard = self.shards[Self::shard_index(fingerprint)]
140            .read()
141            .map_err(|_| A1StorageError::permanent("revocation shard lock poisoned"))?;
142        Ok(shard.contains(fingerprint))
143    }
144
145    fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), A1StorageError> {
146        for (word, bit) in Self::bloom_indices(fingerprint) {
147            self.bloom[word].fetch_or(bit, Ordering::SeqCst);
148        }
149        self.shards[Self::shard_index(fingerprint)]
150            .write()
151            .map_err(|_| A1StorageError::permanent("revocation shard lock poisoned"))?
152            .insert(*fingerprint);
153        Ok(())
154    }
155
156    fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), A1StorageError> {
157        for fp in fingerprints {
158            for (word, bit) in Self::bloom_indices(fp) {
159                self.bloom[word].fetch_or(bit, Ordering::SeqCst);
160            }
161        }
162        for fp in fingerprints {
163            self.shards[Self::shard_index(fp)]
164                .write()
165                .map_err(|_| A1StorageError::permanent("revocation shard lock poisoned"))?
166                .insert(*fp);
167        }
168        Ok(())
169    }
170}
171
172// ── MemoryNonceStore ──────────────────────────────────────────────────────────
173
/// In-memory [`NonceStore`]: a single `RwLock`-protected map from nonce to
/// expiry (Unix seconds), fronted by a lock-free Bloom filter.
///
/// The filter lets `is_consumed` answer "not consumed" for never-seen nonces
/// without taking the lock. Filter bits are only ever set, never cleared.
/// Without a TTL, consumed nonces are recorded with a `u64::MAX` expiry, so
/// they are remembered for the lifetime of the store.
pub struct MemoryNonceStore {
    bloom: Vec<AtomicU64>,
    store: RwLock<HashMap<[u8; 16], u64>>,
    ttl_secs: Option<u64>,
}

impl MemoryNonceStore {
    /// Number of 64-bit words in the filter.
    const BLOOM_WORDS: usize = 1024;

    /// Creates an empty store with no TTL.
    #[must_use]
    pub fn new() -> Self {
        let bloom = (0..Self::BLOOM_WORDS).map(|_| AtomicU64::new(0)).collect();
        Self {
            bloom,
            store: RwLock::new(HashMap::new()),
            ttl_secs: None,
        }
    }

    /// Makes consumed nonces expire `ttl_secs` seconds after they were
    /// consumed, instead of being kept forever.
    #[must_use]
    pub fn with_ttl_secs(mut self, ttl_secs: u64) -> Self {
        self.ttl_secs = Some(ttl_secs);
        self
    }

    /// Single `(word index, bit mask)` filter probe for `nonce`. The word
    /// index is `< BLOOM_WORDS` and the mask has exactly one bit set.
    #[inline]
    fn indices(nonce: &[u8; 16]) -> (usize, u64) {
        const PROVENANCE_SEED: u64 = 0x6479_6F6C_6FA1_2800;
        // Each sub-slice is exactly 8 bytes, so the conversions cannot fail.
        let e1 = u64::from_le_bytes(nonce[0..8].try_into().unwrap());
        let e2 = u64::from_le_bytes(nonce[8..16].try_into().unwrap());
        let h = e1
            .wrapping_mul(PROVENANCE_SEED)
            .wrapping_add(e2.rotate_left(23))
            ^ 0x9E37_79B1_85EB_CA87;
        // PROVENANCE_SEED ends in eleven zero bits, so the low bits of `h`
        // never depend on `e1`. Taking the word from them would ignore the
        // first half of the nonce. Bits 32.. of `h` depend on both halves.
        let word = ((h >> 32) as usize) % Self::BLOOM_WORDS;
        let bit = 1u64 << (h.rotate_right(11) % 64);
        (word, bit)
    }
}

impl Default for MemoryNonceStore {
    fn default() -> Self {
        Self::new()
    }
}
218
219impl NonceStore for MemoryNonceStore {
220    fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError> {
221        let (word, bit) = Self::indices(nonce);
222        let mut guard = self
223            .store
224            .write()
225            .map_err(|_| A1StorageError::permanent("nonce store lock poisoned"))?;
226
227        let now = unix_now();
228
229        if (self.bloom[word].load(Ordering::Acquire) & bit) != 0 {
230            if let Some(&exp) = guard.get(nonce) {
231                if exp >= now {
232                    return Ok(false);
233                }
234            }
235        }
236
237        if let Some(ttl) = self.ttl_secs {
238            guard.retain(|_, exp| *exp >= now);
239            guard.insert(*nonce, now.saturating_add(ttl));
240        } else {
241            guard.insert(*nonce, u64::MAX);
242        }
243
244        self.bloom[word].fetch_or(bit, Ordering::Release);
245        Ok(true)
246    }
247
248    fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, A1StorageError> {
249        let mut guard = self
250            .store
251            .write()
252            .map_err(|_| A1StorageError::permanent("nonce store lock poisoned"))?;
253
254        let now = unix_now();
255
256        for nonce in nonces {
257            let (word, bit) = Self::indices(nonce);
258            if (self.bloom[word].load(Ordering::Acquire) & bit) != 0 {
259                if let Some(&exp) = guard.get(nonce) {
260                    if exp >= now {
261                        return Ok(false);
262                    }
263                }
264            }
265        }
266
267        if let Some(ttl) = self.ttl_secs {
268            guard.retain(|_, exp| *exp >= now);
269            for nonce in nonces {
270                let (word, bit) = Self::indices(nonce);
271                guard.insert(*nonce, now.saturating_add(ttl));
272                self.bloom[word].fetch_or(bit, Ordering::Release);
273            }
274        } else {
275            for nonce in nonces {
276                let (word, bit) = Self::indices(nonce);
277                guard.insert(*nonce, u64::MAX);
278                self.bloom[word].fetch_or(bit, Ordering::Release);
279            }
280        }
281
282        Ok(true)
283    }
284
285    fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError> {
286        let (word, bit) = Self::indices(nonce);
287        if (self.bloom[word].load(Ordering::Acquire) & bit) == 0 {
288            return Ok(false);
289        }
290        let guard = self
291            .store
292            .read()
293            .map_err(|_| A1StorageError::permanent("nonce store lock poisoned"))?;
294        let now = unix_now();
295        if let Some(&exp) = guard.get(nonce) {
296            Ok(exp >= now)
297        } else {
298            Ok(false)
299        }
300    }
301
302    fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), A1StorageError> {
303        self.try_consume(nonce).map(|_| ())
304    }
305}
306
307// ── MemoryRateLimitStore ──────────────────────────────────────────────────────
308
/// Per-key counter for the key's current fixed window.
struct RateBucket {
    /// Unix second at which the counted window began.
    window_start_secs: u64,
    /// Requests admitted so far in that window.
    count: u32,
}

/// In-process fixed-window rate limiter.
///
/// Each unique key (e.g. principal public key bytes, or IP address bytes) gets
/// an independent counter. Windows are aligned to multiples of the window
/// length since the Unix epoch, so a key can be admitted up to
/// `2 × max_per_window` times in quick succession across a window boundary.
///
/// A counter resets when its key is next seen in a later window. Entries whose
/// window started two or more windows ago are swept on the next call to
/// `check_and_record`.
pub struct MemoryRateLimitStore {
    buckets: RwLock<HashMap<Vec<u8>, RateBucket>>,
}

impl MemoryRateLimitStore {
    /// Creates a limiter with no tracked keys.
    #[must_use]
    pub fn new() -> Self {
        Self {
            buckets: RwLock::new(HashMap::new()),
        }
    }
}

impl Default for MemoryRateLimitStore {
    fn default() -> Self {
        Self::new()
    }
}
337
338impl RateLimitStore for MemoryRateLimitStore {
339    fn check_and_record(
340        &self,
341        key: &[u8],
342        max_per_window: u32,
343        window_secs: u64,
344    ) -> Result<bool, A1StorageError> {
345        let now = unix_now();
346        let window_start = (now / window_secs.max(1)) * window_secs.max(1);
347
348        let mut buckets = self
349            .buckets
350            .write()
351            .map_err(|_| A1StorageError::permanent("rate limit store lock poisoned"))?;
352
353        buckets.retain(|_, v| now.saturating_sub(v.window_start_secs) < window_secs.max(1) * 2);
354
355        let bucket = buckets.entry(key.to_vec()).or_insert(RateBucket {
356            window_start_secs: window_start,
357            count: 0,
358        });
359
360        if bucket.window_start_secs != window_start {
361            bucket.window_start_secs = window_start;
362            bucket.count = 0;
363        }
364
365        if bucket.count >= max_per_window {
366            return Ok(false);
367        }
368
369        bucket.count += 1;
370        Ok(true)
371    }
372}
373
374// ── Async storage traits ──────────────────────────────────────────────────────
375
/// Async counterparts of the storage traits, plus adapters that run a
/// synchronous store on Tokio's blocking thread pool.
#[cfg(feature = "async")]
pub mod r#async {
    use crate::error::A1StorageError;
    use async_trait::async_trait;
    use std::collections::HashSet;
    use std::sync::Arc;

    /// Async version of [`super::RevocationStore`].
    #[async_trait]
    pub trait AsyncRevocationStore: Send + Sync {
        /// Returns `Ok(true)` if `fingerprint` has been revoked.
        async fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, A1StorageError>;

        /// Records `fingerprint` as revoked.
        async fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), A1StorageError>;

        /// Revokes each fingerprint in order, stopping at the first error.
        async fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), A1StorageError> {
            for fp in fingerprints {
                self.revoke(fp).await?;
            }
            Ok(())
        }

        /// Health probe; the default implementation always returns `Ok(())`.
        async fn health_check(&self) -> Result<(), A1StorageError> {
            Ok(())
        }
    }

    /// Async version of [`super::NonceStore`].
    #[async_trait]
    pub trait AsyncNonceStore: Send + Sync {
        /// Records `nonce` as consumed if it is not already. Returns whether
        /// this call consumed it.
        async fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError>;

        /// Consumes every nonce in `nonces`. Returns `Ok(false)` instead,
        /// without marking anything, if any of them is already consumed or if
        /// the batch repeats a nonce.
        ///
        /// The default check-then-mark sequence is NOT atomic:
        /// - a concurrent caller can consume a nonce between the two passes;
        /// - if `mark_consumed` fails partway, earlier nonces stay marked.
        async fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, A1StorageError> {
            // Scoped so the set is dropped before the first `.await`.
            // `HashSet::insert` returns false on a repeat.
            let has_duplicate = {
                let mut seen = HashSet::with_capacity(nonces.len());
                !nonces.iter().all(|nonce| seen.insert(nonce))
            };
            if has_duplicate {
                return Ok(false);
            }
            for nonce in nonces {
                if self.is_consumed(nonce).await? {
                    return Ok(false);
                }
            }
            for nonce in nonces {
                self.mark_consumed(nonce).await?;
            }
            Ok(true)
        }

        /// Returns `Ok(true)` if `nonce` is currently recorded as consumed.
        async fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError>;

        /// Records `nonce` as consumed.
        async fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), A1StorageError>;

        /// Health probe; the default implementation always returns `Ok(())`.
        async fn health_check(&self) -> Result<(), A1StorageError> {
            Ok(())
        }
    }

    /// Async version of [`super::RateLimitStore`].
    #[async_trait]
    pub trait AsyncRateLimitStore: Send + Sync {
        /// Checks `key` against its per-window budget and records the attempt.
        async fn check_and_record(
            &self,
            key: &[u8],
            max_per_window: u32,
            window_secs: u64,
        ) -> Result<bool, A1StorageError>;

        /// Health probe; the default implementation always returns `Ok(())`.
        async fn health_check(&self) -> Result<(), A1StorageError> {
            Ok(())
        }
    }

    /// Exposes a synchronous [`super::RevocationStore`] through
    /// [`AsyncRevocationStore`].
    ///
    /// Each call runs via `tokio::task::spawn_blocking`. A failed join of the
    /// blocking task is reported as a transient error.
    pub struct SyncRevocationAdapter<S>(pub Arc<S>);

    #[async_trait]
    impl<S: super::RevocationStore + 'static> AsyncRevocationStore for SyncRevocationAdapter<S> {
        async fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, A1StorageError> {
            let store = Arc::clone(&self.0);
            let fp = *fingerprint;
            tokio::task::spawn_blocking(move || store.is_revoked(&fp))
                .await
                .map_err(|e| A1StorageError::transient(e.to_string()))?
        }

        async fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), A1StorageError> {
            let store = Arc::clone(&self.0);
            let fp = *fingerprint;
            tokio::task::spawn_blocking(move || store.revoke(&fp))
                .await
                .map_err(|e| A1StorageError::transient(e.to_string()))?
        }

        async fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), A1StorageError> {
            let store = Arc::clone(&self.0);
            let fps: Vec<[u8; 32]> = fingerprints.to_vec();
            tokio::task::spawn_blocking(move || store.revoke_batch(&fps))
                .await
                .map_err(|e| A1StorageError::transient(e.to_string()))?
        }

        async fn health_check(&self) -> Result<(), A1StorageError> {
            let store = Arc::clone(&self.0);
            tokio::task::spawn_blocking(move || store.health_check())
                .await
                .map_err(|e| A1StorageError::transient(e.to_string()))?
        }
    }

    /// Exposes a synchronous [`super::RateLimitStore`] through
    /// [`AsyncRateLimitStore`].
    ///
    /// Each call runs via `tokio::task::spawn_blocking`. A failed join of the
    /// blocking task is reported as a transient error.
    pub struct SyncRateLimitAdapter<S>(pub Arc<S>);

    #[async_trait]
    impl<S: super::RateLimitStore + 'static> AsyncRateLimitStore for SyncRateLimitAdapter<S> {
        async fn check_and_record(
            &self,
            key: &[u8],
            max_per_window: u32,
            window_secs: u64,
        ) -> Result<bool, A1StorageError> {
            let store = Arc::clone(&self.0);
            let k = key.to_vec();
            tokio::task::spawn_blocking(move || {
                store.check_and_record(&k, max_per_window, window_secs)
            })
            .await
            .map_err(|e| A1StorageError::transient(e.to_string()))?
        }

        async fn health_check(&self) -> Result<(), A1StorageError> {
            let store = Arc::clone(&self.0);
            tokio::task::spawn_blocking(move || store.health_check())
                .await
                .map_err(|e| A1StorageError::transient(e.to_string()))?
        }
    }

    /// Exposes a synchronous [`super::NonceStore`] through [`AsyncNonceStore`].
    ///
    /// Every method, including `try_consume_batch`, is forwarded to the sync
    /// store's own implementation via `tokio::task::spawn_blocking`. A failed
    /// join of the blocking task is reported as a transient error.
    pub struct SyncNonceAdapter<S>(pub Arc<S>);

    #[async_trait]
    impl<S: super::NonceStore + 'static> AsyncNonceStore for SyncNonceAdapter<S> {
        async fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError> {
            let store = Arc::clone(&self.0);
            let n = *nonce;
            tokio::task::spawn_blocking(move || store.try_consume(&n))
                .await
                .map_err(|e| A1StorageError::transient(e.to_string()))?
        }

        async fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, A1StorageError> {
            let store = Arc::clone(&self.0);
            let ns = nonces.to_vec();
            tokio::task::spawn_blocking(move || store.try_consume_batch(&ns))
                .await
                .map_err(|e| A1StorageError::transient(e.to_string()))?
        }

        async fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, A1StorageError> {
            let store = Arc::clone(&self.0);
            let n = *nonce;
            tokio::task::spawn_blocking(move || store.is_consumed(&n))
                .await
                .map_err(|e| A1StorageError::transient(e.to_string()))?
        }

        async fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), A1StorageError> {
            let store = Arc::clone(&self.0);
            let n = *nonce;
            tokio::task::spawn_blocking(move || store.mark_consumed(&n))
                .await
                .map_err(|e| A1StorageError::transient(e.to_string()))?
        }

        async fn health_check(&self) -> Result<(), A1StorageError> {
            let store = Arc::clone(&self.0);
            tokio::task::spawn_blocking(move || store.health_check())
                .await
                .map_err(|e| A1StorageError::transient(e.to_string()))?
        }
    }
}
544
545// ── Tests ─────────────────────────────────────────────────────────────────────
546
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn try_consume_is_atomic_and_idempotent() {
        let store = MemoryNonceStore::new();
        let nonce = fresh_nonce();
        assert!(store.try_consume(&nonce).unwrap());
        assert!(!store.try_consume(&nonce).unwrap());
        assert!(store.is_consumed(&nonce).unwrap());
    }

    #[test]
    fn try_consume_concurrent_exactly_one_winner() {
        use std::{sync::Arc, thread};
        let store = Arc::new(MemoryNonceStore::new());
        let nonce = fresh_nonce();
        let handles: Vec<_> = (0..32)
            .map(|_| {
                let s = Arc::clone(&store);
                thread::spawn(move || s.try_consume(&nonce).unwrap())
            })
            .collect();
        let wins: usize = handles
            .into_iter()
            .map(|h| h.join().unwrap() as usize)
            .sum();
        assert_eq!(wins, 1);
    }

    #[test]
    fn try_consume_batch_rejects_when_any_nonce_is_consumed() {
        let store = MemoryNonceStore::new();
        let (a, b) = (fresh_nonce(), fresh_nonce());
        assert!(store.try_consume(&a).unwrap());
        // The whole batch is refused, and the fresh nonce is left untouched.
        assert!(!store.try_consume_batch(&[a, b]).unwrap());
        assert!(!store.is_consumed(&b).unwrap());
        assert!(store.try_consume_batch(&[b]).unwrap());
        assert!(store.is_consumed(&b).unwrap());
    }

    #[test]
    fn ttl_store_rejects_replay_within_ttl() {
        let store = MemoryNonceStore::new().with_ttl_secs(3600);
        let nonce = fresh_nonce();
        assert!(store.try_consume(&nonce).unwrap());
        assert!(!store.try_consume(&nonce).unwrap());
        assert!(store.is_consumed(&nonce).unwrap());
    }

    #[test]
    fn revoke_batch_marks_all_fingerprints() {
        let store = MemoryRevocationStore::new();
        let fps: Vec<[u8; 32]> = (0..8u8)
            .map(|i| {
                let mut f = [0u8; 32];
                f[0] = i;
                f
            })
            .collect();
        store.revoke_batch(&fps).unwrap();
        for fp in &fps {
            assert!(store.is_revoked(fp).unwrap());
        }
    }

    #[test]
    fn is_revoked_only_after_revoke() {
        let store = MemoryRevocationStore::new();
        let fp = [7u8; 32];
        assert!(!store.is_revoked(&fp).unwrap());
        store.revoke(&fp).unwrap();
        assert!(store.is_revoked(&fp).unwrap());
        // A Bloom false positive must still be resolved by the shard lookup.
        assert!(!store.is_revoked(&[8u8; 32]).unwrap());
    }

    #[test]
    fn health_check_returns_ok_for_memory_stores() {
        assert!(MemoryRevocationStore::new().health_check().is_ok());
        assert!(MemoryNonceStore::new().health_check().is_ok());
        assert!(MemoryRateLimitStore::new().health_check().is_ok());
    }

    #[test]
    fn rate_limit_enforces_window() {
        let store = MemoryRateLimitStore::new();
        let key = b"test-principal";
        for _ in 0..5 {
            assert!(
                store.check_and_record(key, 5, 60).unwrap(),
                "should be allowed"
            );
        }
        assert!(
            !store.check_and_record(key, 5, 60).unwrap(),
            "should be blocked"
        );
    }

    #[test]
    fn rate_limit_keys_are_independent() {
        let store = MemoryRateLimitStore::new();
        // A very long window keeps the test clear of window-boundary resets.
        let window = 1u64 << 40;
        assert!(store.check_and_record(b"alice", 1, window).unwrap());
        assert!(!store.check_and_record(b"alice", 1, window).unwrap());
        assert!(store.check_and_record(b"bob", 1, window).unwrap());
        assert!(!store.check_and_record(b"carol", 0, window).unwrap());
    }
}