1use std::collections::{HashMap, HashSet};
2use std::sync::{
3 atomic::{AtomicU64, Ordering},
4 RwLock,
5};
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use rand::{rngs::OsRng, RngCore};
9
10use crate::error::KyaStorageError;
11
12pub fn fresh_nonce() -> [u8; 16] {
13 let mut nonce = [0u8; 16];
14 OsRng.fill_bytes(&mut nonce);
15 nonce
16}
17
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error.
fn unix_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
24
25pub trait RevocationStore: Send + Sync {
28 fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, KyaStorageError>;
29 fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), KyaStorageError>;
30
31 fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), KyaStorageError> {
32 for fp in fingerprints {
33 self.revoke(fp)?;
34 }
35 Ok(())
36 }
37
38 fn health_check(&self) -> Result<(), KyaStorageError> {
39 Ok(())
40 }
41}
42
43pub trait NonceStore: Send + Sync {
46 fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError>;
47
48 fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, KyaStorageError> {
49 for nonce in nonces {
50 if self.is_consumed(nonce)? {
51 return Ok(false);
52 }
53 }
54 for nonce in nonces {
55 self.mark_consumed(nonce)?;
56 }
57 Ok(true)
58 }
59
60 fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError>;
61 fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), KyaStorageError>;
62
63 fn health_check(&self) -> Result<(), KyaStorageError> {
64 Ok(())
65 }
66}
67
/// Storage backend for per-key request rate limiting.
///
/// Implementations must be usable from multiple threads (`Send + Sync`).
pub trait RateLimitStore: Send + Sync {
    /// Decides whether one more request for `key` fits within `max_per_window` requests per
    /// `window_secs`-second window, and records it if so.
    ///
    /// Returns `Ok(true)` when the request is allowed and `Ok(false)` when the quota for the
    /// current window is exhausted. Window alignment, and whether rejected requests count
    /// toward the quota, are left to the implementation — NOTE(review): confirm the intended
    /// contract.
    fn check_and_record(
        &self,
        key: &[u8],
        max_per_window: u32,
        window_secs: u64,
    ) -> Result<bool, KyaStorageError>;

    /// Backend health probe; the default implementation always returns `Ok(())`.
    fn health_check(&self) -> Result<(), KyaStorageError> {
        Ok(())
    }
}
90
/// In-memory revocation store: a small Bloom filter in front of lock-sharded hash sets.
///
/// The filter has `BLOOM_WORDS` × 64 bits and three probes per fingerprint. It lets lookups
/// of non-revoked fingerprints return without taking a lock. The `RwLock`-protected sets
/// are the authoritative record. Nothing in this type clears filter bits or removes set
/// entries.
pub struct MemoryRevocationStore {
    /// Bloom filter words. A fingerprint can only be revoked if all three of its probe bits
    /// are set.
    bloom: Vec<AtomicU64>,
    /// Authoritative revoked fingerprints, partitioned by `shard_index`.
    shards: Vec<RwLock<HashSet<[u8; 32]>>>,
}

impl MemoryRevocationStore {
    /// Number of 64-bit words in the Bloom filter (4096 bits in total).
    const BLOOM_WORDS: usize = 64;
    /// Number of independently locked fingerprint sets.
    const SHARDS: usize = 64;

    /// Creates an empty store.
    #[must_use]
    pub fn new() -> Self {
        let bloom = (0..Self::BLOOM_WORDS).map(|_| AtomicU64::new(0)).collect();
        let shards = (0..Self::SHARDS)
            .map(|_| RwLock::new(HashSet::new()))
            .collect();
        Self { bloom, shards }
    }

    /// Reads the little-endian `u64` stored in `fp[offset..offset + 8]`.
    #[inline]
    fn read_u64(fp: &[u8; 32], offset: usize) -> u64 {
        let mut bytes = [0u8; 8];
        bytes.copy_from_slice(&fp[offset..offset + 8]);
        u64::from_le_bytes(bytes)
    }

    /// Derives the three Bloom probes `(word index, bit mask)` for `fp`.
    ///
    /// Bytes 0..8, 8..16 and 16..24 each supply one `u64`. Its low bits select the word, and
    /// a rotated copy selects the bit within that word. This presumably relies on
    /// fingerprints being uniformly distributed (e.g. a cryptographic hash) — TODO confirm.
    #[inline]
    fn bloom_indices(fp: &[u8; 32]) -> [(usize, u64); 3] {
        let probe = |offset: usize, rot: u32| {
            let h = Self::read_u64(fp, offset);
            (
                (h % Self::BLOOM_WORDS as u64) as usize,
                1u64 << (h.rotate_right(rot) % 64),
            )
        };
        [probe(0, 13), probe(8, 27), probe(16, 41)]
    }

    /// Selects the shard for `fp` from its last eight bytes (24..32), which the Bloom probes
    /// do not use.
    #[inline]
    fn shard_index(fp: &[u8; 32]) -> usize {
        (Self::read_u64(fp, 24) % Self::SHARDS as u64) as usize
    }
}

impl Default for MemoryRevocationStore {
    /// Equivalent to [`MemoryRevocationStore::new`].
    fn default() -> Self {
        Self::new()
    }
}
129
130impl RevocationStore for MemoryRevocationStore {
131 fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, KyaStorageError> {
132 for (word, bit) in Self::bloom_indices(fingerprint) {
133 if (self.bloom[word].load(Ordering::Relaxed) & bit) == 0 {
134 return Ok(false);
135 }
136 }
137 let shard = self.shards[Self::shard_index(fingerprint)]
138 .read()
139 .map_err(|_| KyaStorageError::permanent("revocation shard lock poisoned"))?;
140 Ok(shard.contains(fingerprint))
141 }
142
143 fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), KyaStorageError> {
144 for (word, bit) in Self::bloom_indices(fingerprint) {
145 self.bloom[word].fetch_or(bit, Ordering::SeqCst);
146 }
147 self.shards[Self::shard_index(fingerprint)]
148 .write()
149 .map_err(|_| KyaStorageError::permanent("revocation shard lock poisoned"))?
150 .insert(*fingerprint);
151 Ok(())
152 }
153
154 fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), KyaStorageError> {
155 for fp in fingerprints {
156 for (word, bit) in Self::bloom_indices(fp) {
157 self.bloom[word].fetch_or(bit, Ordering::SeqCst);
158 }
159 }
160 for fp in fingerprints {
161 self.shards[Self::shard_index(fp)]
162 .write()
163 .map_err(|_| KyaStorageError::permanent("revocation shard lock poisoned"))?
164 .insert(*fp);
165 }
166 Ok(())
167 }
168}
169
/// In-memory nonce store: a map from consumed nonce to expiry, fronted by a single-probe
/// Bloom filter.
///
/// `is_consumed` returns without locking when the nonce's filter bit is clear. Filter bits
/// are only ever set. After entries expire, the filter can therefore give more false
/// positives (which fall through to the map), but never false negatives.
pub struct MemoryNonceStore {
    /// `BLOOM_WORDS` × 64-bit filter; each nonce maps to one bit (see `indices`).
    bloom: Vec<AtomicU64>,
    /// Consumed nonce → Unix-second expiry, inclusive (still consumed while
    /// `expiry >= now`). The value is `u64::MAX` when no TTL is configured.
    store: RwLock<HashMap<[u8; 16], u64>>,
    /// Lifetime of a consumed nonce in seconds; `None` keeps nonces forever.
    ttl_secs: Option<u64>,
}

impl MemoryNonceStore {
    /// Number of 64-bit words in the Bloom filter (65 536 bits in total).
    const BLOOM_WORDS: usize = 1024;

    /// Creates an empty store whose consumed nonces never expire.
    #[must_use]
    pub fn new() -> Self {
        let bloom = (0..Self::BLOOM_WORDS).map(|_| AtomicU64::new(0)).collect();
        Self {
            bloom,
            store: RwLock::new(HashMap::new()),
            ttl_secs: None,
        }
    }

    /// Builder: consumed nonces expire `ttl_secs` seconds after consumption. After that,
    /// the same nonce is accepted again.
    #[must_use]
    pub fn with_ttl_secs(mut self, ttl_secs: u64) -> Self {
        self.ttl_secs = Some(ttl_secs);
        self
    }

    /// Maps `nonce` to its Bloom probe `(word index, bit mask)` by mixing both 8-byte halves.
    #[inline]
    fn indices(nonce: &[u8; 16]) -> (usize, u64) {
        let e1 = u64::from_le_bytes(nonce[0..8].try_into().unwrap());
        let e2 = u64::from_le_bytes(nonce[8..16].try_into().unwrap());
        let h = e1
            .wrapping_mul(0x9E3779B185EBCA87)
            .wrapping_add(e2.rotate_left(23));
        let word = (h as usize) % Self::BLOOM_WORDS;
        let bit = 1u64 << (h.rotate_right(11) % 64);
        (word, bit)
    }
}

impl Default for MemoryNonceStore {
    /// Equivalent to [`MemoryNonceStore::new`].
    fn default() -> Self {
        Self::new()
    }
}
214
215impl NonceStore for MemoryNonceStore {
216 fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError> {
217 let (word, bit) = Self::indices(nonce);
218 let mut guard = self
219 .store
220 .write()
221 .map_err(|_| KyaStorageError::permanent("nonce store lock poisoned"))?;
222
223 let now = unix_now();
224
225 if (self.bloom[word].load(Ordering::Acquire) & bit) != 0 {
226 if let Some(&exp) = guard.get(nonce) {
227 if exp >= now {
228 return Ok(false);
229 }
230 }
231 }
232
233 if let Some(ttl) = self.ttl_secs {
234 guard.retain(|_, exp| *exp >= now);
235 guard.insert(*nonce, now.saturating_add(ttl));
236 } else {
237 guard.insert(*nonce, u64::MAX);
238 }
239
240 self.bloom[word].fetch_or(bit, Ordering::Release);
241 Ok(true)
242 }
243
244 fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, KyaStorageError> {
245 let mut guard = self
246 .store
247 .write()
248 .map_err(|_| KyaStorageError::permanent("nonce store lock poisoned"))?;
249
250 let now = unix_now();
251
252 for nonce in nonces {
253 let (word, bit) = Self::indices(nonce);
254 if (self.bloom[word].load(Ordering::Acquire) & bit) != 0 {
255 if let Some(&exp) = guard.get(nonce) {
256 if exp >= now {
257 return Ok(false);
258 }
259 }
260 }
261 }
262
263 if let Some(ttl) = self.ttl_secs {
264 guard.retain(|_, exp| *exp >= now);
265 for nonce in nonces {
266 let (word, bit) = Self::indices(nonce);
267 guard.insert(*nonce, now.saturating_add(ttl));
268 self.bloom[word].fetch_or(bit, Ordering::Release);
269 }
270 } else {
271 for nonce in nonces {
272 let (word, bit) = Self::indices(nonce);
273 guard.insert(*nonce, u64::MAX);
274 self.bloom[word].fetch_or(bit, Ordering::Release);
275 }
276 }
277
278 Ok(true)
279 }
280
281 fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError> {
282 let (word, bit) = Self::indices(nonce);
283 if (self.bloom[word].load(Ordering::Acquire) & bit) == 0 {
284 return Ok(false);
285 }
286 let guard = self
287 .store
288 .read()
289 .map_err(|_| KyaStorageError::permanent("nonce store lock poisoned"))?;
290 let now = unix_now();
291 if let Some(&exp) = guard.get(nonce) {
292 Ok(exp >= now)
293 } else {
294 Ok(false)
295 }
296 }
297
298 fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), KyaStorageError> {
299 self.try_consume(nonce).map(|_| ())
300 }
301}
302
303struct RateBucket {
306 window_start_secs: u64,
307 count: u32,
308}
309
310pub struct MemoryRateLimitStore {
316 buckets: RwLock<HashMap<Vec<u8>, RateBucket>>,
317}
318
319impl MemoryRateLimitStore {
320 #[must_use]
321 pub fn new() -> Self {
322 Self {
323 buckets: RwLock::new(HashMap::new()),
324 }
325 }
326}
327
328impl Default for MemoryRateLimitStore {
329 fn default() -> Self {
330 Self::new()
331 }
332}
333
334impl RateLimitStore for MemoryRateLimitStore {
335 fn check_and_record(
336 &self,
337 key: &[u8],
338 max_per_window: u32,
339 window_secs: u64,
340 ) -> Result<bool, KyaStorageError> {
341 let now = unix_now();
342 let window_start = (now / window_secs.max(1)) * window_secs.max(1);
343
344 let mut buckets = self
345 .buckets
346 .write()
347 .map_err(|_| KyaStorageError::permanent("rate limit store lock poisoned"))?;
348
349 buckets.retain(|_, v| now.saturating_sub(v.window_start_secs) < window_secs.max(1) * 2);
350
351 let bucket = buckets.entry(key.to_vec()).or_insert(RateBucket {
352 window_start_secs: window_start,
353 count: 0,
354 });
355
356 if bucket.window_start_secs != window_start {
357 bucket.window_start_secs = window_start;
358 bucket.count = 0;
359 }
360
361 if bucket.count >= max_per_window {
362 return Ok(false);
363 }
364
365 bucket.count += 1;
366 Ok(true)
367 }
368}
369
/// Async counterparts of the storage traits, plus adapters that run a synchronous store on
/// Tokio's blocking thread pool. Compiled only with the `async` feature.
#[cfg(feature = "async")]
pub mod r#async {
    use crate::error::KyaStorageError;
    use async_trait::async_trait;
    use std::collections::HashSet;
    use std::sync::Arc;

    /// Returns `true` if some nonce occurs more than once in `nonces`.
    fn has_duplicate_nonce(nonces: &[[u8; 16]]) -> bool {
        let mut seen = HashSet::with_capacity(nonces.len());
        !nonces.iter().all(|nonce| seen.insert(nonce))
    }

    /// Runs the blocking closure `f` via `tokio::task::spawn_blocking` and returns its result.
    ///
    /// A `JoinError` (the closure panicked or the task was cancelled) becomes a transient
    /// storage error carrying the join error's message.
    async fn run_blocking<T, F>(f: F) -> Result<T, KyaStorageError>
    where
        F: FnOnce() -> Result<T, KyaStorageError> + Send + 'static,
        T: Send + 'static,
    {
        tokio::task::spawn_blocking(f)
            .await
            .map_err(|e| KyaStorageError::transient(e.to_string()))?
    }

    /// Async version of the synchronous `RevocationStore` trait.
    #[async_trait]
    pub trait AsyncRevocationStore: Send + Sync {
        /// Reports whether `fingerprint` has been revoked.
        async fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, KyaStorageError>;

        /// Marks `fingerprint` as revoked.
        async fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), KyaStorageError>;

        /// Revokes each fingerprint in order, stopping at the first error (not atomic).
        async fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), KyaStorageError> {
            for fp in fingerprints {
                self.revoke(fp).await?;
            }
            Ok(())
        }

        /// Backend health probe; the default implementation always returns `Ok(())`.
        async fn health_check(&self) -> Result<(), KyaStorageError> {
            Ok(())
        }
    }

    /// Async version of the synchronous `NonceStore` trait.
    #[async_trait]
    pub trait AsyncNonceStore: Send + Sync {
        /// Consumes `nonce`; `Ok(false)` means it had already been consumed.
        async fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError>;

        /// Consumes all of `nonces`, provided none is already consumed.
        ///
        /// Returns `Ok(false)` if a nonce repeats within the batch (a replay) or any nonce
        /// is already consumed. The default is composed from `is_consumed` and
        /// `mark_consumed` and is therefore not atomic against concurrent callers.
        async fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, KyaStorageError> {
            if has_duplicate_nonce(nonces) {
                return Ok(false);
            }
            for nonce in nonces {
                if self.is_consumed(nonce).await? {
                    return Ok(false);
                }
            }
            for nonce in nonces {
                self.mark_consumed(nonce).await?;
            }
            Ok(true)
        }

        /// Reports whether `nonce` is currently recorded as consumed.
        async fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError>;

        /// Records `nonce` as consumed.
        async fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), KyaStorageError>;

        /// Backend health probe; the default implementation always returns `Ok(())`.
        async fn health_check(&self) -> Result<(), KyaStorageError> {
            Ok(())
        }
    }

    /// Async version of the synchronous `RateLimitStore` trait.
    #[async_trait]
    pub trait AsyncRateLimitStore: Send + Sync {
        /// Returns `Ok(true)` and records the request if `key` is within its quota.
        async fn check_and_record(
            &self,
            key: &[u8],
            max_per_window: u32,
            window_secs: u64,
        ) -> Result<bool, KyaStorageError>;

        /// Backend health probe; the default implementation always returns `Ok(())`.
        async fn health_check(&self) -> Result<(), KyaStorageError> {
            Ok(())
        }
    }

    /// Exposes a synchronous `RevocationStore` as an [`AsyncRevocationStore`]. Each call
    /// runs on Tokio's blocking thread pool.
    pub struct SyncRevocationAdapter<S>(pub Arc<S>);

    #[async_trait]
    impl<S: super::RevocationStore + 'static> AsyncRevocationStore for SyncRevocationAdapter<S> {
        async fn is_revoked(&self, fingerprint: &[u8; 32]) -> Result<bool, KyaStorageError> {
            let store = Arc::clone(&self.0);
            let fp = *fingerprint;
            run_blocking(move || store.is_revoked(&fp)).await
        }

        async fn revoke(&self, fingerprint: &[u8; 32]) -> Result<(), KyaStorageError> {
            let store = Arc::clone(&self.0);
            let fp = *fingerprint;
            run_blocking(move || store.revoke(&fp)).await
        }

        async fn revoke_batch(&self, fingerprints: &[[u8; 32]]) -> Result<(), KyaStorageError> {
            let store = Arc::clone(&self.0);
            let fps: Vec<[u8; 32]> = fingerprints.to_vec();
            run_blocking(move || store.revoke_batch(&fps)).await
        }

        async fn health_check(&self) -> Result<(), KyaStorageError> {
            let store = Arc::clone(&self.0);
            run_blocking(move || store.health_check()).await
        }
    }

    /// Exposes a synchronous `NonceStore` as an [`AsyncNonceStore`]. Each call runs on
    /// Tokio's blocking thread pool, and batches are delegated whole so the sync store's own
    /// `try_consume_batch` is used.
    pub struct SyncNonceAdapter<S>(pub Arc<S>);

    #[async_trait]
    impl<S: super::NonceStore + 'static> AsyncNonceStore for SyncNonceAdapter<S> {
        async fn try_consume(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError> {
            let store = Arc::clone(&self.0);
            let n = *nonce;
            run_blocking(move || store.try_consume(&n)).await
        }

        async fn try_consume_batch(&self, nonces: &[[u8; 16]]) -> Result<bool, KyaStorageError> {
            let store = Arc::clone(&self.0);
            let ns = nonces.to_vec();
            run_blocking(move || store.try_consume_batch(&ns)).await
        }

        async fn is_consumed(&self, nonce: &[u8; 16]) -> Result<bool, KyaStorageError> {
            let store = Arc::clone(&self.0);
            let n = *nonce;
            run_blocking(move || store.is_consumed(&n)).await
        }

        async fn mark_consumed(&self, nonce: &[u8; 16]) -> Result<(), KyaStorageError> {
            let store = Arc::clone(&self.0);
            let n = *nonce;
            run_blocking(move || store.mark_consumed(&n)).await
        }

        async fn health_check(&self) -> Result<(), KyaStorageError> {
            let store = Arc::clone(&self.0);
            run_blocking(move || store.health_check()).await
        }
    }

    /// Exposes a synchronous `RateLimitStore` as an [`AsyncRateLimitStore`]. Each call runs
    /// on Tokio's blocking thread pool.
    pub struct SyncRateLimitAdapter<S>(pub Arc<S>);

    #[async_trait]
    impl<S: super::RateLimitStore + 'static> AsyncRateLimitStore for SyncRateLimitAdapter<S> {
        async fn check_and_record(
            &self,
            key: &[u8],
            max_per_window: u32,
            window_secs: u64,
        ) -> Result<bool, KyaStorageError> {
            let store = Arc::clone(&self.0);
            let key = key.to_vec();
            run_blocking(move || store.check_and_record(&key, max_per_window, window_secs)).await
        }

        async fn health_check(&self) -> Result<(), KyaStorageError> {
            let store = Arc::clone(&self.0);
            run_blocking(move || store.health_check()).await
        }
    }
}
513
#[cfg(test)]
mod tests {
    use super::*;
    use std::{sync::Arc, thread};

    /// Builds a 32-byte fingerprint whose first byte is `tag` and whose other bytes are zero.
    fn fingerprint(tag: u8) -> [u8; 32] {
        let mut fp = [0u8; 32];
        fp[0] = tag;
        fp
    }

    #[test]
    fn try_consume_is_atomic_and_idempotent() {
        let store = MemoryNonceStore::new();
        let nonce = fresh_nonce();
        let first = store.try_consume(&nonce).unwrap();
        assert!(first);
        let second = store.try_consume(&nonce).unwrap();
        assert!(!second);
        assert!(store.is_consumed(&nonce).unwrap());
    }

    #[test]
    fn try_consume_concurrent_exactly_one_winner() {
        let store = Arc::new(MemoryNonceStore::new());
        let nonce = fresh_nonce();
        let mut workers = Vec::with_capacity(32);
        for _ in 0..32 {
            let shared = Arc::clone(&store);
            workers.push(thread::spawn(move || shared.try_consume(&nonce).unwrap()));
        }
        let wins = workers
            .into_iter()
            .map(|worker| worker.join().unwrap())
            .filter(|&won| won)
            .count();
        assert_eq!(wins, 1);
    }

    #[test]
    fn revoke_batch_marks_all_fingerprints() {
        let store = MemoryRevocationStore::new();
        let fps: Vec<[u8; 32]> = (0..8u8).map(fingerprint).collect();
        store.revoke_batch(&fps).unwrap();
        for fp in fps.iter() {
            assert!(store.is_revoked(fp).unwrap());
        }
    }

    #[test]
    fn health_check_returns_ok_for_memory_stores() {
        assert!(MemoryRevocationStore::new().health_check().is_ok());
        assert!(MemoryNonceStore::new().health_check().is_ok());
        assert!(MemoryRateLimitStore::new().health_check().is_ok());
    }

    #[test]
    fn rate_limit_enforces_window() {
        let store = MemoryRateLimitStore::new();
        let key = b"test-principal";
        (0..5).for_each(|_| {
            assert!(
                store.check_and_record(key, 5, 60).unwrap(),
                "should be allowed"
            );
        });
        let over_limit = store.check_and_record(key, 5, 60).unwrap();
        assert!(!over_limit, "should be blocked");
    }
}