// dyolo_kya/registry.rs

use std::collections::{HashMap, HashSet};
use std::sync::{
    atomic::{AtomicU64, Ordering},
    RwLock,
};
use std::time::{SystemTime, UNIX_EPOCH};

use rand::{rngs::OsRng, RngCore};

use crate::error::KyaStorageError;
11
12pub fn fresh_nonce() -> [u8; 16] {
13    let mut nonce = [0u8; 16];
14    OsRng.fill_bytes(&mut nonce);
15    nonce
16}
17
/// Seconds since the Unix epoch; yields 0 if the system clock reads
/// earlier than 1970 (same fallback as `unwrap_or_default`).
fn unix_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
24
25// ── RevocationStore ───────────────────────────────────────────────────────────
26
27pub trait RevocationStore: Send + Sync {
28    fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, KyaStorageError>;
29    fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), KyaStorageError>;
30
31    fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), KyaStorageError> {
32        for fp in fingerprints {
33            self.revoke(fp)?;
34        }
35        Ok(())
36    }
37
38    fn health_check(&self) -> Result<(), KyaStorageError> {
39        Ok(())
40    }
41}
42
// ── NonceStore ────────────────────────────────────────────────────────────────

/// Replay protection: each nonce may be consumed at most once (within its
/// TTL, when the implementation applies one).
pub trait NonceStore: Send + Sync {
    /// Consumes `nonce`: `Ok(true)` on first use, `Ok(false)` when already
    /// consumed. Implementations should make the check-and-set atomic (the
    /// in-memory store does so under a single write lock).
    fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError>;

    /// Consumes all nonces, or none when any was already consumed.
    ///
    /// NOTE(review): this default is check-then-mark in two passes and is
    /// NOT atomic across concurrent callers; implementations that need
    /// all-or-nothing semantics under concurrency should override it.
    fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, KyaStorageError> {
        for nonce in nonces {
            if self.is_consumed(nonce)? {
                return Ok(false);
            }
        }
        for nonce in nonces {
            self.mark_consumed(nonce)?;
        }
        Ok(true)
    }

    /// Returns `true` when `nonce` has been consumed (and, for TTL-based
    /// implementations, has not yet expired).
    fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError>;

    /// Records `nonce` as consumed regardless of prior state.
    fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), KyaStorageError>;

    /// Backend liveness probe; the default always succeeds.
    fn health_check(&self) -> Result<(), KyaStorageError> {
        Ok(())
    }
}
67
// ── RateLimitStore ────────────────────────────────────────────────────────────

/// Token-bucket rate limiter storage abstraction.
///
/// `check_and_record` must atomically verify that the key has not exceeded its
/// budget, then record the attempt. Returns `Ok(true)` when the request is
/// within the limit and `Ok(false)` when the limit has been reached.
///
/// Implement this trait over Redis (sliding-window counter with INCR + EXPIRE)
/// or Postgres (per-key row with an atomic UPDATE) for distributed deployments.
pub trait RateLimitStore: Send + Sync {
    /// Checks `key` against `max_per_window` attempts per `window_secs`
    /// seconds and records the attempt when it is allowed.
    fn check_and_record(
        &self,
        key: &[u8],
        max_per_window: u32,
        window_secs: u64,
    ) -> Result<bool, KyaStorageError>;

    /// Backend liveness probe; the default always succeeds.
    fn health_check(&self) -> Result<(), KyaStorageError> {
        Ok(())
    }
}
90
// ── MemoryRevocationStore ─────────────────────────────────────────────────────

/// In-process revocation set: a small atomic bloom filter provides a
/// lock-free negative fast path, backed by lock-sharded `HashSet`s holding
/// the exact fingerprints.
pub struct MemoryRevocationStore {
    // 64 words * 64 bits = 4096-bit bloom filter; bits are only ever set,
    // never cleared (revocations are permanent).
    bloom: Vec<AtomicU64>,
    // Exact membership, sharded 64 ways by the last 8 fingerprint bytes.
    shards: Vec<RwLock<HashSet<[u8; 32]>>>,
}
97
98impl MemoryRevocationStore {
99    #[must_use]
100    pub fn new() -> Self {
101        let bloom = (0..64).map(|_| AtomicU64::new(0)).collect();
102        let shards = (0..64).map(|_| RwLock::new(HashSet::new())).collect();
103        Self { bloom, shards }
104    }
105
106    #[inline(always)]
107    fn bloom_indices(fp: &[u8; 32]) -> [(usize, u64); 3] {
108        let h1 = u64::from_le_bytes(fp[0..8].try_into().unwrap());
109        let h2 = u64::from_le_bytes(fp[8..16].try_into().unwrap());
110        let h3 = u64::from_le_bytes(fp[16..24].try_into().unwrap());
111        [
112            ((h1 % 64) as usize, 1u64 << (h1.rotate_right(13) % 64)),
113            ((h2 % 64) as usize, 1u64 << (h2.rotate_right(27) % 64)),
114            ((h3 % 64) as usize, 1u64 << (h3.rotate_right(41) % 64)),
115        ]
116    }
117
118    #[inline(always)]
119    fn shard_index(fp: &[u8; 32]) -> usize {
120        (u64::from_le_bytes(fp[24..32].try_into().unwrap()) % 64) as usize
121    }
122}
123
124impl Default for MemoryRevocationStore {
125    fn default() -> Self {
126        Self::new()
127    }
128}
129
impl RevocationStore for MemoryRevocationStore {
    /// Bloom fast path: if any of the three probe bits is clear the
    /// fingerprint was never revoked and the shard lock is skipped entirely.
    /// A full bloom hit may be a false positive, so the owning shard is
    /// consulted for the exact answer.
    fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, KyaStorageError> {
        for (word, bit) in Self::bloom_indices(fingerprint) {
            // NOTE(review): loads are Relaxed while `revoke` stores with
            // SeqCst; a reader racing an in-flight revoke may briefly answer
            // `false`. Presumably callers tolerate that window — confirm.
            if (self.bloom[word].load(Ordering::Relaxed) & bit) == 0 {
                return Ok(false);
            }
        }
        let shard = self.shards[Self::shard_index(fingerprint)]
            .read()
            .map_err(|_| KyaStorageError::permanent("revocation shard lock poisoned"))?;
        Ok(shard.contains(fingerprint))
    }

    /// Sets the bloom bits before inserting into the exact set, so any
    /// reader that would find the shard entry also passes the bloom
    /// pre-check.
    fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), KyaStorageError> {
        for (word, bit) in Self::bloom_indices(fingerprint) {
            self.bloom[word].fetch_or(bit, Ordering::SeqCst);
        }
        self.shards[Self::shard_index(fingerprint)]
            .write()
            .map_err(|_| KyaStorageError::permanent("revocation shard lock poisoned"))?
            .insert(*fingerprint);
        Ok(())
    }

    /// Batch form: publishes all bloom bits first, then inserts into the
    /// shards one fingerprint at a time (one write-lock acquisition each).
    fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), KyaStorageError> {
        for fp in fingerprints {
            for (word, bit) in Self::bloom_indices(fp) {
                self.bloom[word].fetch_or(bit, Ordering::SeqCst);
            }
        }
        for fp in fingerprints {
            self.shards[Self::shard_index(fp)]
                .write()
                .map_err(|_| KyaStorageError::permanent("revocation shard lock poisoned"))?
                .insert(*fp);
        }
        Ok(())
    }
}
169
// ── MemoryNonceStore ──────────────────────────────────────────────────────────

/// In-process replay guard: a 1024-word atomic bloom filter in front of an
/// exact nonce → expiry map.
pub struct MemoryNonceStore {
    // Bits are only ever set; a saturated filter degrades to map lookups,
    // never to false negatives.
    bloom: Vec<AtomicU64>,
    // nonce -> unix expiry seconds (`u64::MAX` when no TTL is configured).
    store: RwLock<HashMap<[u8; 16], u64>>,
    // `None` means consumed nonces never expire.
    ttl_secs: Option<u64>,
}
177
178impl MemoryNonceStore {
179    const BLOOM_WORDS: usize = 1024;
180
181    #[must_use]
182    pub fn new() -> Self {
183        let bloom = (0..Self::BLOOM_WORDS).map(|_| AtomicU64::new(0)).collect();
184        Self {
185            bloom,
186            store: RwLock::new(HashMap::new()),
187            ttl_secs: None,
188        }
189    }
190
191    pub fn with_ttl_secs(mut self, ttl_secs: u64) -> Self {
192        self.ttl_secs = Some(ttl_secs);
193        self
194    }
195
196    #[inline(always)]
197    fn indices(nonce: &[u8; 16]) -> (usize, u64) {
198        let e1 = u64::from_le_bytes(nonce[0..8].try_into().unwrap());
199        let e2 = u64::from_le_bytes(nonce[8..16].try_into().unwrap());
200        let h = e1
201            .wrapping_mul(0x9E3779B185EBCA87)
202            .wrapping_add(e2.rotate_left(23));
203        let word = (h as usize) % Self::BLOOM_WORDS;
204        let bit = 1u64 << (h.rotate_right(11) % 64);
205        (word, bit)
206    }
207}
208
209impl Default for MemoryNonceStore {
210    fn default() -> Self {
211        Self::new()
212    }
213}
214
impl NonceStore for MemoryNonceStore {
    /// Consumes `nonce` under a single write lock, so among concurrent
    /// callers exactly one wins. In TTL mode an expired entry may be
    /// consumed again.
    fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError> {
        let (word, bit) = Self::indices(nonce);
        let mut guard = self
            .store
            .write()
            .map_err(|_| KyaStorageError::permanent("nonce store lock poisoned"))?;

        let now = unix_now();

        // Bloom bit set => the nonce *may* be present; confirm in the map
        // and reject only when the entry is still unexpired.
        if (self.bloom[word].load(Ordering::Acquire) & bit) != 0 {
            if let Some(&exp) = guard.get(nonce) {
                if exp >= now {
                    return Ok(false);
                }
            }
        }

        if let Some(ttl) = self.ttl_secs {
            // NOTE(review): this full-map sweep runs on every consume, so
            // each call is O(map size); consider amortizing if stores grow.
            guard.retain(|_, exp| *exp >= now);
            guard.insert(*nonce, now.saturating_add(ttl));
        } else {
            // No TTL: mark consumed forever.
            guard.insert(*nonce, u64::MAX);
        }

        // Publish the bloom bit only after the map insert (both under the
        // write lock), so a clear bit reliably means "absent".
        self.bloom[word].fetch_or(bit, Ordering::Release);
        Ok(true)
    }

    /// All-or-nothing batch consume under one write lock: if any nonce in
    /// the batch is already consumed (and unexpired), nothing is recorded.
    fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, KyaStorageError> {
        let mut guard = self
            .store
            .write()
            .map_err(|_| KyaStorageError::permanent("nonce store lock poisoned"))?;

        let now = unix_now();

        // Pass 1: verify no nonce in the batch was already consumed.
        // NOTE(review): duplicates *within* one batch pass this check and
        // are all accepted below — confirm that is intended.
        for nonce in nonces {
            let (word, bit) = Self::indices(nonce);
            if (self.bloom[word].load(Ordering::Acquire) & bit) != 0 {
                if let Some(&exp) = guard.get(nonce) {
                    if exp >= now {
                        return Ok(false);
                    }
                }
            }
        }

        // Pass 2: record every nonce (map insert, then bloom publish).
        if let Some(ttl) = self.ttl_secs {
            guard.retain(|_, exp| *exp >= now);
            for nonce in nonces {
                let (word, bit) = Self::indices(nonce);
                guard.insert(*nonce, now.saturating_add(ttl));
                self.bloom[word].fetch_or(bit, Ordering::Release);
            }
        } else {
            for nonce in nonces {
                let (word, bit) = Self::indices(nonce);
                guard.insert(*nonce, u64::MAX);
                self.bloom[word].fetch_or(bit, Ordering::Release);
            }
        }

        Ok(true)
    }

    /// Read-only check: a bloom miss answers `false` without locking; a
    /// bloom hit falls through to the map and honors expiry.
    fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError> {
        let (word, bit) = Self::indices(nonce);
        if (self.bloom[word].load(Ordering::Acquire) & bit) == 0 {
            return Ok(false);
        }
        let guard = self
            .store
            .read()
            .map_err(|_| KyaStorageError::permanent("nonce store lock poisoned"))?;
        let now = unix_now();
        if let Some(&exp) = guard.get(nonce) {
            Ok(exp >= now)
        } else {
            Ok(false)
        }
    }

    /// Marks `nonce` consumed; idempotent — the `try_consume` boolean is
    /// deliberately discarded.
    fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), KyaStorageError> {
        self.try_consume(nonce).map(|_| ())
    }
}
302
// ── MemoryRateLimitStore ──────────────────────────────────────────────────────

/// Per-key counter for one fixed rate-limit window.
struct RateBucket {
    // Unix-seconds start of the window this count belongs to.
    window_start_secs: u64,
    // Attempts recorded within the current window.
    count: u32,
}
309
/// In-process sliding-window rate limiter.
///
/// Each unique key (e.g. principal public key bytes, or IP address bytes) gets
/// an independent bucket. Buckets reset at the start of each window and entries
/// are evicted lazily on the next write after they expire.
pub struct MemoryRateLimitStore {
    // key bytes -> current-window bucket, guarded by one store-wide lock.
    buckets: RwLock<HashMap<Vec<u8>, RateBucket>>,
}
318
319impl MemoryRateLimitStore {
320    #[must_use]
321    pub fn new() -> Self {
322        Self {
323            buckets: RwLock::new(HashMap::new()),
324        }
325    }
326}
327
328impl Default for MemoryRateLimitStore {
329    fn default() -> Self {
330        Self::new()
331    }
332}
333
334impl RateLimitStore for MemoryRateLimitStore {
335    fn check_and_record(
336        &self,
337        key: &[u8],
338        max_per_window: u32,
339        window_secs: u64,
340    ) -> Result<bool, KyaStorageError> {
341        let now = unix_now();
342        let window_start = (now / window_secs.max(1)) * window_secs.max(1);
343
344        let mut buckets = self
345            .buckets
346            .write()
347            .map_err(|_| KyaStorageError::permanent("rate limit store lock poisoned"))?;
348
349        buckets.retain(|_, v| now.saturating_sub(v.window_start_secs) < window_secs.max(1) * 2);
350
351        let bucket = buckets.entry(key.to_vec()).or_insert(RateBucket {
352            window_start_secs: window_start,
353            count: 0,
354        });
355
356        if bucket.window_start_secs != window_start {
357            bucket.window_start_secs = window_start;
358            bucket.count = 0;
359        }
360
361        if bucket.count >= max_per_window {
362            return Ok(false);
363        }
364
365        bucket.count += 1;
366        Ok(true)
367    }
368}
369
// ── Async storage traits ──────────────────────────────────────────────────────

#[cfg(feature = "async")]
pub mod r#async {
    //! Async mirrors of the sync storage traits, plus adapters that run a
    //! sync store on the tokio blocking thread pool.
    use crate::error::KyaStorageError;
    use async_trait::async_trait;
    use std::sync::Arc;

    /// Async counterpart of [`super::RevocationStore`].
    #[async_trait]
    pub trait AsyncRevocationStore: Send + Sync {
        /// Returns `true` when `fingerprint` has been revoked.
        async fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, KyaStorageError>;
        /// Marks `fingerprint` as revoked.
        async fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), KyaStorageError>;

        /// Revokes sequentially, stopping at the first error.
        async fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), KyaStorageError> {
            for fp in fingerprints {
                self.revoke(fp).await?;
            }
            Ok(())
        }

        /// Backend liveness probe; the default always succeeds.
        async fn health_check(&self) -> Result<(), KyaStorageError> {
            Ok(())
        }
    }

    /// Async counterpart of [`super::NonceStore`].
    #[async_trait]
    pub trait AsyncNonceStore: Send + Sync {
        /// Consumes `nonce`: `Ok(true)` on first use, `Ok(false)` otherwise.
        async fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError>;

        /// Consumes all nonces, or none when any was already consumed.
        /// NOTE(review): like the sync default, this check-then-mark pair is
        /// not atomic across concurrent callers.
        async fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, KyaStorageError> {
            for nonce in nonces {
                if self.is_consumed(nonce).await? {
                    return Ok(false);
                }
            }
            for nonce in nonces {
                self.mark_consumed(nonce).await?;
            }
            Ok(true)
        }

        /// Returns `true` when `nonce` has been consumed.
        async fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError>;
        /// Records `nonce` as consumed regardless of prior state.
        async fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), KyaStorageError>;

        /// Backend liveness probe; the default always succeeds.
        async fn health_check(&self) -> Result<(), KyaStorageError> {
            Ok(())
        }
    }

    /// Async counterpart of [`super::RateLimitStore`].
    #[async_trait]
    pub trait AsyncRateLimitStore: Send + Sync {
        /// Checks `key` against its budget and records the attempt if allowed.
        async fn check_and_record(
            &self,
            key: &[u8],
            max_per_window: u32,
            window_secs: u64,
        ) -> Result<bool, KyaStorageError>;

        /// Backend liveness probe; the default always succeeds.
        async fn health_check(&self) -> Result<(), KyaStorageError> {
            Ok(())
        }
    }

    /// Wraps a sync [`super::RevocationStore`] so it can be used where an
    /// [`AsyncRevocationStore`] is required; each call runs on tokio's
    /// blocking pool. Join failures surface as transient storage errors.
    pub struct SyncRevocationAdapter<S>(pub Arc<S>);

    #[async_trait]
    impl<S: super::RevocationStore + 'static> AsyncRevocationStore for SyncRevocationAdapter<S> {
        async fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, KyaStorageError> {
            let store = Arc::clone(&self.0);
            // Copy the borrowed argument so the closure can be 'static.
            let fp = *fingerprint;
            tokio::task::spawn_blocking(move || store.is_revoked(&fp))
                .await
                .map_err(|e| KyaStorageError::transient(e.to_string()))?
        }

        async fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), KyaStorageError> {
            let store = Arc::clone(&self.0);
            let fp = *fingerprint;
            tokio::task::spawn_blocking(move || store.revoke(&fp))
                .await
                .map_err(|e| KyaStorageError::transient(e.to_string()))?
        }

        async fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), KyaStorageError> {
            let store = Arc::clone(&self.0);
            // Owned copy of the slice for the 'static closure.
            let fps: Vec<[u8; 32]> = fingerprints.to_vec();
            tokio::task::spawn_blocking(move || store.revoke_batch(&fps))
                .await
                .map_err(|e| KyaStorageError::transient(e.to_string()))?
        }

        async fn health_check(&self) -> Result<(), KyaStorageError> {
            let store = Arc::clone(&self.0);
            tokio::task::spawn_blocking(move || store.health_check())
                .await
                .map_err(|e| KyaStorageError::transient(e.to_string()))?
        }
    }

    /// Wraps a sync [`super::NonceStore`] as an [`AsyncNonceStore`]; each
    /// call runs on tokio's blocking pool. Join failures surface as
    /// transient storage errors.
    pub struct SyncNonceAdapter<S>(pub Arc<S>);

    #[async_trait]
    impl<S: super::NonceStore + 'static> AsyncNonceStore for SyncNonceAdapter<S> {
        async fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError> {
            let store = Arc::clone(&self.0);
            let n = *nonce;
            tokio::task::spawn_blocking(move || store.try_consume(&n))
                .await
                .map_err(|e| KyaStorageError::transient(e.to_string()))?
        }

        async fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, KyaStorageError> {
            let store = Arc::clone(&self.0);
            let ns = nonces.to_vec();
            tokio::task::spawn_blocking(move || store.try_consume_batch(&ns))
                .await
                .map_err(|e| KyaStorageError::transient(e.to_string()))?
        }

        async fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError> {
            let store = Arc::clone(&self.0);
            let n = *nonce;
            tokio::task::spawn_blocking(move || store.is_consumed(&n))
                .await
                .map_err(|e| KyaStorageError::transient(e.to_string()))?
        }

        async fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), KyaStorageError> {
            let store = Arc::clone(&self.0);
            let n = *nonce;
            tokio::task::spawn_blocking(move || store.mark_consumed(&n))
                .await
                .map_err(|e| KyaStorageError::transient(e.to_string()))?
        }

        async fn health_check(&self) -> Result<(), KyaStorageError> {
            let store = Arc::clone(&self.0);
            tokio::task::spawn_blocking(move || store.health_check())
                .await
                .map_err(|e| KyaStorageError::transient(e.to_string()))?
        }
    }
}
513
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // First consume wins; a second attempt on the same nonce is rejected and
    // the nonce reads back as consumed.
    #[test]
    fn try_consume_is_atomic_and_idempotent() {
        let store = MemoryNonceStore::new();
        let nonce = fresh_nonce();
        assert!(store.try_consume(&nonce).unwrap());
        assert!(!store.try_consume(&nonce).unwrap());
        assert!(store.is_consumed(&nonce).unwrap());
    }

    // 32 threads race on the same nonce; the write lock must admit exactly
    // one winner.
    #[test]
    fn try_consume_concurrent_exactly_one_winner() {
        use std::{sync::Arc, thread};
        let store = Arc::new(MemoryNonceStore::new());
        let nonce = fresh_nonce();
        let handles: Vec<_> = (0..32)
            .map(|_| {
                let s = Arc::clone(&store);
                thread::spawn(move || s.try_consume(&nonce).unwrap())
            })
            .collect();
        // `true as usize` == 1, so the sum counts winners.
        let wins: usize = handles
            .into_iter()
            .map(|h| h.join().unwrap() as usize)
            .sum();
        assert_eq!(wins, 1);
    }

    // Every fingerprint in a batch revoke must subsequently read as revoked.
    #[test]
    fn revoke_batch_marks_all_fingerprints() {
        let store = MemoryRevocationStore::new();
        let fps: Vec<[u8; 32]> = (0..8u8)
            .map(|i| {
                let mut f = [0u8; 32];
                f[0] = i;
                f
            })
            .collect();
        store.revoke_batch(&fps).unwrap();
        for fp in &fps {
            assert!(store.is_revoked(fp).unwrap());
        }
    }

    // The default trait health checks are infallible for in-memory stores.
    #[test]
    fn health_check_returns_ok_for_memory_stores() {
        assert!(MemoryRevocationStore::new().health_check().is_ok());
        assert!(MemoryNonceStore::new().health_check().is_ok());
        assert!(MemoryRateLimitStore::new().health_check().is_ok());
    }

    // A budget of 5 per 60s admits exactly 5 attempts, then blocks.
    #[test]
    fn rate_limit_enforces_window() {
        let store = MemoryRateLimitStore::new();
        let key = b"test-principal";
        for _ in 0..5 {
            assert!(
                store.check_and_record(key, 5, 60).unwrap(),
                "should be allowed"
            );
        }
        assert!(
            !store.check_and_record(key, 5, 60).unwrap(),
            "should be blocked"
        );
    }
}
585}