#[cfg(feature = "arc-swap")]
use arc_swap::ArcSwap;
#[cfg(feature = "parking-lot")]
use parking_lot::RwLock as ParkingRwLock;
#[cfg(not(any(feature = "arc-swap", feature = "parking-lot")))]
use std::sync::RwLock as StdRwLock;
#[cfg(feature = "arc-swap")]
use std::sync::Arc;
/// Snapshot of every stored key together with the version currently in use.
///
/// Slot `i` of `keys` holds the key stored under version `i`, or `None` when
/// nothing has been stored there. The type does not derive `Debug`, which keeps
/// key bytes out of `{:?}` output.
#[derive(Clone)]
struct SecretInner<const V: usize, const S: usize> {
    /// Version whose key is handed out as the current one.
    current_version: u8,
    /// One optional `S`-byte key per version slot (`V` slots in total).
    keys: [Option<[u8; S]>; V],
}
/// Read access to a versioned set of `S`-byte secrets addressed by `u8`
/// version numbers, with `V` version slots.
///
/// Implementors must be `Send + Sync`, so a group can be shared between threads.
pub trait SecretGroup<const V: usize = 256, const S: usize = 32>
where
    Self: Send + Sync,
{
    /// Returns the active version number together with its key.
    fn current(&self) -> (u8, [u8; S]);

    /// Returns the key stored under `version`, or `None` if there is none.
    fn resolve(&self, version: u8) -> Option<[u8; S]>;
}
/// In-process [`SecretGroup`] implementation storing up to `V` versioned
/// `S`-byte keys.
///
/// The primitive guarding `inner` is selected at compile time:
/// `ArcSwap` when the `arc-swap` feature is enabled (it takes precedence over
/// `parking-lot`), `parking_lot::RwLock` when only `parking-lot` is enabled,
/// and `std::sync::RwLock` when neither feature is enabled.
pub struct InMemorySecretGroup<const V: usize = 256, const S: usize = 32> {
    // Exactly one of the three `inner` fields below is compiled in; the cfg
    // predicates are mutually exclusive.
    #[cfg(feature = "arc-swap")]
    inner: ArcSwap<SecretInner<V, S>>,
    #[cfg(all(feature = "parking-lot", not(feature = "arc-swap")))]
    inner: ParkingRwLock<SecretInner<V, S>>,
    #[cfg(not(any(feature = "arc-swap", feature = "parking-lot")))]
    inner: StdRwLock<SecretInner<V, S>>,
}
impl<const V: usize, const S: usize> InMemorySecretGroup<V, S> {
pub fn new(version: u8, initial_key: [u8; S]) -> Self {
assert!(
(version as usize) < V,
"version {} out of range for ring buffer of size {V}",
version
);
let mut keys: [Option<[u8; S]>; V] = std::array::from_fn(|_| None);
keys[version as usize] = Some(initial_key);
let inner_val = SecretInner {
keys,
current_version: version,
};
Self {
#[cfg(feature = "arc-swap")]
inner: ArcSwap::from_pointee(inner_val),
#[cfg(all(feature = "parking-lot", not(feature = "arc-swap")))]
inner: ParkingRwLock::new(inner_val),
#[cfg(not(any(feature = "arc-swap", feature = "parking-lot")))]
inner: StdRwLock::new(inner_val),
}
}
pub fn store_key(&self, version: u8, key: [u8; S]) {
assert!(
(version as usize) < V,
"version {} out of range for ring buffer of size {V}",
version
);
#[cfg(feature = "arc-swap")]
{
let mut inner = (**self.inner.load()).clone();
inner.keys[version as usize] = Some(key);
self.inner.store(Arc::new(inner));
}
#[cfg(all(feature = "parking-lot", not(feature = "arc-swap")))]
{
let mut inner = self.inner.write();
inner.keys[version as usize] = Some(key);
}
#[cfg(not(any(feature = "arc-swap", feature = "parking-lot")))]
{
let mut inner = self.inner.write().expect("lock poisoned");
inner.keys[version as usize] = Some(key);
}
}
pub fn promote(&self, version: u8) {
assert!(
(version as usize) < V,
"version {} out of range for ring buffer of size {V}",
version
);
#[cfg(feature = "arc-swap")]
{
let mut inner = (**self.inner.load()).clone();
if inner.keys[version as usize].is_none() {
panic!("cannot promote version {version} before it is stored");
}
inner.current_version = version;
self.inner.store(Arc::new(inner));
}
#[cfg(all(feature = "parking-lot", not(feature = "arc-swap")))]
{
let mut inner = self.inner.write();
if inner.keys[version as usize].is_none() {
panic!("cannot promote version {version} before it is stored");
}
inner.current_version = version;
}
#[cfg(not(any(feature = "arc-swap", feature = "parking-lot")))]
{
let mut inner = self.inner.write().expect("lock poisoned");
if inner.keys[version as usize].is_none() {
panic!("cannot promote version {version} before it is stored");
}
inner.current_version = version;
}
}
pub fn apply(&self, version: u8, key: [u8; S]) {
assert!(
(version as usize) < V,
"version {} out of range for ring buffer of size {V}",
version
);
#[cfg(feature = "arc-swap")]
{
let mut inner = (**self.inner.load()).clone();
inner.keys[version as usize] = Some(key);
inner.current_version = version;
self.inner.store(Arc::new(inner));
}
#[cfg(all(feature = "parking-lot", not(feature = "arc-swap")))]
{
let mut inner = self.inner.write();
inner.keys[version as usize] = Some(key);
inner.current_version = version;
}
#[cfg(not(any(feature = "arc-swap", feature = "parking-lot")))]
{
let mut inner = self.inner.write().expect("lock poisoned");
inner.keys[version as usize] = Some(key);
inner.current_version = version;
}
}
}
impl<const V: usize, const S: usize> SecretGroup<V, S> for InMemorySecretGroup<V, S> {
    /// Returns the current version and its key.
    ///
    /// Only the current slot is copied out while the state is borrowed, rather
    /// than the whole `V`-slot key array.
    ///
    /// # Panics
    ///
    /// Panics if the current version's slot is empty. The constructor and
    /// mutators keep it populated. With the std backend, this also panics if
    /// the lock is poisoned.
    fn current(&self) -> (u8, [u8; S]) {
        #[cfg(feature = "arc-swap")]
        let (v, slot) = {
            let inner = self.inner.load();
            (inner.current_version, inner.keys[inner.current_version as usize])
        };
        #[cfg(all(feature = "parking-lot", not(feature = "arc-swap")))]
        let (v, slot) = {
            let inner = self.inner.read();
            (inner.current_version, inner.keys[inner.current_version as usize])
        };
        #[cfg(not(any(feature = "arc-swap", feature = "parking-lot")))]
        let (v, slot) = {
            let inner = self.inner.read().expect("lock poisoned");
            (inner.current_version, inner.keys[inner.current_version as usize])
        };
        let key = slot.expect("current_version slot must always be populated");
        (v, key)
    }

    /// Returns the key stored under `version`.
    ///
    /// Returns `None` if the slot is empty or if `version` lies outside the
    /// `V`-slot buffer (possible when `V < 256`). An out-of-range version
    /// therefore cannot crash the caller, even when it comes from untrusted
    /// input.
    fn resolve(&self, version: u8) -> Option<[u8; S]> {
        let slot = version as usize;
        #[cfg(feature = "arc-swap")]
        let found = self.inner.load().keys.get(slot).copied().flatten();
        #[cfg(all(feature = "parking-lot", not(feature = "arc-swap")))]
        let found = self.inner.read().keys.get(slot).copied().flatten();
        #[cfg(not(any(feature = "arc-swap", feature = "parking-lot")))]
        let found = self
            .inner
            .read()
            .expect("lock poisoned")
            .keys
            .get(slot)
            .copied()
            .flatten();
        found
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::thread;

    const KEY_A: [u8; 32] = [1u8; 32];
    const KEY_B: [u8; 32] = [2u8; 32];

    #[test]
    fn new_returns_initial_key_as_current() {
        let sg = InMemorySecretGroup::<256, 32>::new(0, KEY_A);
        let (v, k) = sg.current();
        assert_eq!(v, 0);
        assert_eq!(k, KEY_A);
    }

    #[test]
    fn resolve_returns_none_for_unpopulated_slot() {
        let sg = InMemorySecretGroup::<256, 32>::new(0, KEY_A);
        assert!(sg.resolve(1).is_none());
        assert!(sg.resolve(255).is_none());
    }

    #[test]
    fn resolve_returns_some_for_populated_slot() {
        let sg = InMemorySecretGroup::<256, 32>::new(0, KEY_A);
        assert_eq!(sg.resolve(0), Some(KEY_A));
    }

    #[test]
    fn apply_updates_current_and_ring() {
        let sg = InMemorySecretGroup::<256, 32>::new(0, KEY_A);
        sg.apply(1, KEY_B);
        let (v, k) = sg.current();
        assert_eq!(v, 1);
        assert_eq!(k, KEY_B);
        assert_eq!(sg.resolve(0), Some(KEY_A));
        assert_eq!(sg.resolve(1), Some(KEY_B));
    }

    #[test]
    fn store_key_then_promote_switches_current() {
        let sg = InMemorySecretGroup::<256, 32>::new(0, KEY_A);
        sg.store_key(1, KEY_B);
        // Storing alone must not change which key is current.
        assert_eq!(sg.current(), (0, KEY_A));
        assert_eq!(sg.resolve(1), Some(KEY_B));
        sg.promote(1);
        assert_eq!(sg.current(), (1, KEY_B));
        assert_eq!(sg.resolve(0), Some(KEY_A));
    }

    #[test]
    #[should_panic(expected = "cannot promote version 7 before it is stored")]
    fn promote_rejects_unstored_version() {
        let sg = InMemorySecretGroup::<256, 32>::new(0, KEY_A);
        sg.promote(7);
    }

    #[test]
    #[should_panic(expected = "out of range")]
    fn new_rejects_version_outside_buffer() {
        let _ = InMemorySecretGroup::<4, 32>::new(4, KEY_A);
    }

    /// Runs a reader on a real OS thread while the test thread keeps writing.
    /// A current-thread async runtime would only run the spawned reader after
    /// all writes had finished, so the accesses would never overlap.
    #[test]
    fn concurrent_reads_during_apply_are_safe() {
        let sg = InMemorySecretGroup::<256, 32>::new(0, KEY_A);
        let writer_done = AtomicBool::new(false);
        thread::scope(|scope| {
            let reader = scope.spawn(|| {
                let mut reads = 0u32;
                // Keep reading until the writer is done (and at least 1000 times)
                // so reads overlap with writes.
                while reads < 1000 || !writer_done.load(Ordering::Acquire) {
                    let (v, k) = sg.current();
                    assert!(k == KEY_A || k == KEY_B);
                    // Slots are never cleared, so the current version must resolve.
                    assert!(sg.resolve(v).is_some(), "current version {v} must resolve");
                    reads += 1;
                }
            });
            for _ in 0..100 {
                for i in 0u8..10 {
                    sg.apply(i, KEY_B);
                }
            }
            writer_done.store(true, Ordering::Release);
            reader.join().expect("reader must not panic");
        });
    }
}