use std::fmt::{Display, Formatter};
use actionqueue_core::ids::RunId;
/// Name of a resource that at most one run may hold at a time.
///
/// Keys are compared, hashed, and displayed by their string value. A key is
/// never empty: [`ConcurrencyKey::new`] and both `From` conversions panic on
/// an empty string, while [`ConcurrencyKey::try_new`] returns `None` instead.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConcurrencyKey(String);

impl ConcurrencyKey {
    /// Creates a key from anything convertible into a `String`.
    ///
    /// # Panics
    ///
    /// Panics with `"ConcurrencyKey must not be empty"` if the resulting
    /// string is empty. Use [`ConcurrencyKey::try_new`] for input that may be
    /// empty (e.g. user-supplied configuration).
    pub fn new(key: impl Into<String>) -> Self {
        let value = key.into();
        assert!(!value.is_empty(), "ConcurrencyKey must not be empty");
        Self(value)
    }

    /// Fallible counterpart of [`ConcurrencyKey::new`]: returns `None` instead
    /// of panicking when the resulting string is empty.
    #[must_use]
    pub fn try_new(key: impl Into<String>) -> Option<Self> {
        let value = key.into();
        (!value.is_empty()).then(|| Self(value))
    }

    /// Returns the key as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

impl From<String> for ConcurrencyKey {
    /// Equivalent to [`ConcurrencyKey::new`]; panics on an empty string.
    fn from(value: String) -> Self {
        Self::new(value)
    }
}

impl From<&str> for ConcurrencyKey {
    /// Equivalent to [`ConcurrencyKey::new`]; panics on an empty string.
    fn from(value: &str) -> Self {
        // `new` already accepts `impl Into<String>`, so no manual `to_owned`.
        Self::new(value)
    }
}

impl Display for ConcurrencyKey {
    /// Writes the raw key text.
    ///
    /// Uses `Formatter::pad` so width, alignment and precision flags
    /// (e.g. `{:>12}`, `{:<8}`) are honoured; re-entering `write!` would
    /// silently drop them.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}
/// Outcome of [`KeyGate::acquire`].
///
/// Marked `#[must_use]`: callers should inspect the variant to learn whether
/// the requesting run actually holds the key.
#[derive(Debug, Clone, PartialEq, Eq)]
#[must_use]
pub enum AcquireResult {
    /// The key is held by `run_id`: either it was free and has just been
    /// recorded for `run_id`, or `run_id` already held it.
    Acquired {
        /// The key that was requested.
        key: ConcurrencyKey,
        /// The run that now holds the key.
        run_id: RunId,
    },
    /// The key is held by a different run; the gate was not modified.
    Occupied {
        /// The key that was requested.
        key: ConcurrencyKey,
        /// The run that currently holds the key.
        holder_run_id: RunId,
    },
}
/// Outcome of [`KeyGate::release`].
///
/// Marked `#[must_use]`: callers should inspect the variant to learn whether
/// the key was actually freed.
#[derive(Debug, Clone, PartialEq, Eq)]
#[must_use]
pub enum ReleaseResult {
    /// The key was held by the releasing run and has been removed from the gate.
    Released {
        /// The key that was released.
        key: ConcurrencyKey,
    },
    /// The releasing run did not hold the key: it was either free or held by
    /// another run. The gate was not modified.
    NotHeld {
        /// The key whose release was requested.
        key: ConcurrencyKey,
        /// The run that attempted the release.
        attempting_run_id: RunId,
    },
}
/// Tracks which run currently holds each [`ConcurrencyKey`].
///
/// Each key maps to at most one holding run. Every mutating method takes
/// `&mut self`; the gate does no internal locking of its own.
#[derive(Debug, Clone, Default)]
pub struct KeyGate {
    /// Currently held keys, each mapped to the run holding it.
    occupied_keys: std::collections::HashMap<ConcurrencyKey, RunId>,
}

impl KeyGate {
    /// Creates an empty gate in which no key is held.
    pub fn new() -> Self {
        Self { occupied_keys: std::collections::HashMap::new() }
    }

    /// Attempts to take `key` on behalf of `run_id`.
    ///
    /// Returns [`AcquireResult::Acquired`] when the key was free (it is now
    /// recorded as held by `run_id`) or is already held by `run_id`. In the
    /// latter case the gate is unchanged, so a single [`KeyGate::release`]
    /// frees the key regardless of how many times the same run acquired it.
    ///
    /// Returns [`AcquireResult::Occupied`] with the current holder when a
    /// different run holds the key; the gate is left unchanged.
    pub fn acquire(&mut self, key: ConcurrencyKey, run_id: RunId) -> AcquireResult {
        match self.occupied_keys.get(&key).copied() {
            Some(holder) if holder == run_id => {
                tracing::debug!(%key, %run_id, "concurrency key re-acquired by same run");
                AcquireResult::Acquired { key, run_id }
            }
            Some(holder) => {
                tracing::debug!(%key, %run_id, holder = %holder, "concurrency key occupied");
                AcquireResult::Occupied { key, holder_run_id: holder }
            }
            None => {
                tracing::debug!(%key, %run_id, "concurrency key acquired");
                self.occupied_keys.insert(key.clone(), run_id);
                AcquireResult::Acquired { key, run_id }
            }
        }
    }

    /// Releases `key` if, and only if, it is currently held by `run_id`.
    ///
    /// Returns [`ReleaseResult::Released`] after removing the key. Returns
    /// [`ReleaseResult::NotHeld`] when the key is free or held by another run;
    /// the gate is left unchanged in that case.
    pub fn release(&mut self, key: ConcurrencyKey, run_id: RunId) -> ReleaseResult {
        match self.occupied_keys.get(&key).copied() {
            Some(holder) if holder == run_id => {
                self.occupied_keys.remove(&key);
                tracing::debug!(%key, %run_id, "concurrency key released");
                ReleaseResult::Released { key }
            }
            // Free key or a different holder: both are rejected identically,
            // but log the actual holder (if any) to make misuse diagnosable.
            holder => {
                tracing::debug!(
                    %key,
                    %run_id,
                    holder = ?holder,
                    "concurrency key release rejected: not held by run"
                );
                ReleaseResult::NotHeld { key, attempting_run_id: run_id }
            }
        }
    }

    /// Returns `true` if any run currently holds `key`.
    #[must_use]
    pub fn is_key_occupied(&self, key: &ConcurrencyKey) -> bool {
        self.occupied_keys.contains_key(key)
    }

    /// Returns the run currently holding `key`, or `None` if it is free.
    #[must_use]
    pub fn key_holder(&self, key: &ConcurrencyKey) -> Option<RunId> {
        self.occupied_keys.get(key).copied()
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `ConcurrencyKey` and `KeyGate` acquire/release semantics.
    use super::*;

    #[test]
    fn acquire_succeeds_when_key_is_free() {
        let mut gate = KeyGate::new();
        let key = ConcurrencyKey::new("my-key");
        let run_id = RunId::new();
        let result = gate.acquire(key.clone(), run_id);
        match result {
            AcquireResult::Acquired { key: acquired_key, run_id: acquired_run_id } => {
                assert_eq!(acquired_key, key);
                assert_eq!(acquired_run_id, run_id);
            }
            AcquireResult::Occupied { .. } => panic!("Expected acquisition to succeed"),
        }
        assert!(gate.is_key_occupied(&key));
        assert_eq!(gate.key_holder(&key), Some(run_id));
    }

    #[test]
    fn acquire_fails_when_key_is_occupied_by_different_run() {
        let mut gate = KeyGate::new();
        let key = ConcurrencyKey::new("my-key");
        let holder_run_id = RunId::new();
        let requesting_run_id = RunId::new();
        let _ = gate.acquire(key.clone(), holder_run_id);
        let result = gate.acquire(key.clone(), requesting_run_id);
        match result {
            AcquireResult::Occupied { key: occupied_key, holder_run_id: occupied_holder } => {
                assert_eq!(occupied_key, key);
                assert_eq!(occupied_holder, holder_run_id);
            }
            AcquireResult::Acquired { .. } => panic!("Expected acquisition to fail"),
        }
        assert_eq!(gate.key_holder(&key), Some(holder_run_id));
    }

    #[test]
    fn acquire_succeeds_when_same_run_reacquires_key() {
        let mut gate = KeyGate::new();
        let key = ConcurrencyKey::new("my-key");
        let run_id = RunId::new();
        let _ = gate.acquire(key.clone(), run_id);
        let result = gate.acquire(key.clone(), run_id);
        match result {
            AcquireResult::Acquired { key: acquired_key, run_id: acquired_run_id } => {
                assert_eq!(acquired_key, key);
                assert_eq!(acquired_run_id, run_id);
            }
            AcquireResult::Occupied { .. } => panic!("Expected re-acquisition to succeed"),
        }
        assert!(gate.is_key_occupied(&key));
        assert_eq!(gate.key_holder(&key), Some(run_id));
    }

    #[test]
    fn release_releases_key_held_by_same_run() {
        let mut gate = KeyGate::new();
        let key = ConcurrencyKey::new("my-key");
        let run_id = RunId::new();
        let _ = gate.acquire(key.clone(), run_id);
        let result = gate.release(key.clone(), run_id);
        match result {
            ReleaseResult::Released { key: released_key } => {
                assert_eq!(released_key, key);
            }
            ReleaseResult::NotHeld { .. } => panic!("Expected release to succeed"),
        }
        assert!(!gate.is_key_occupied(&key));
        assert_eq!(gate.key_holder(&key), None);
    }

    #[test]
    fn release_fails_when_key_held_by_different_run() {
        let mut gate = KeyGate::new();
        let key = ConcurrencyKey::new("my-key");
        let holder_run_id = RunId::new();
        let attempting_run_id = RunId::new();
        let _ = gate.acquire(key.clone(), holder_run_id);
        let result = gate.release(key.clone(), attempting_run_id);
        match result {
            ReleaseResult::NotHeld { key: released_key, attempting_run_id: attempted_run_id } => {
                assert_eq!(released_key, key);
                assert_eq!(attempted_run_id, attempting_run_id);
            }
            ReleaseResult::Released { .. } => panic!("Expected release to fail"),
        }
        assert!(gate.is_key_occupied(&key));
        assert_eq!(gate.key_holder(&key), Some(holder_run_id));
    }

    #[test]
    fn release_has_no_effect_when_key_is_free() {
        let mut gate = KeyGate::new();
        let key = ConcurrencyKey::new("my-key");
        let run_id = RunId::new();
        let result = gate.release(key.clone(), run_id);
        match result {
            ReleaseResult::NotHeld { .. } => {}
            ReleaseResult::Released { .. } => panic!("Expected release to have no effect"),
        }
        assert!(!gate.is_key_occupied(&key));
    }

    #[test]
    fn different_keys_can_be_occupied_simultaneously() {
        let mut gate = KeyGate::new();
        let key1 = ConcurrencyKey::new("key-1");
        let key2 = ConcurrencyKey::new("key-2");
        let run1_id = RunId::new();
        let run2_id = RunId::new();
        let _ = gate.acquire(key1.clone(), run1_id);
        let _ = gate.acquire(key2.clone(), run2_id);
        assert!(gate.is_key_occupied(&key1));
        assert!(gate.is_key_occupied(&key2));
        assert_eq!(gate.key_holder(&key1), Some(run1_id));
        assert_eq!(gate.key_holder(&key2), Some(run2_id));
    }

    // Re-acquiring by the same run is not reference-counted: one release frees it.
    #[test]
    fn single_release_frees_key_after_same_run_reacquire() {
        let mut gate = KeyGate::new();
        let key = ConcurrencyKey::new("my-key");
        let run_id = RunId::new();
        let _ = gate.acquire(key.clone(), run_id);
        let _ = gate.acquire(key.clone(), run_id);
        let first = gate.release(key.clone(), run_id);
        assert_eq!(first, ReleaseResult::Released { key: key.clone() });
        assert!(!gate.is_key_occupied(&key));
        let second = gate.release(key.clone(), run_id);
        assert_eq!(second, ReleaseResult::NotHeld { key: key.clone(), attempting_run_id: run_id });
    }

    #[test]
    fn key_can_be_acquired_by_another_run_after_release() {
        let mut gate = KeyGate::new();
        let key = ConcurrencyKey::new("my-key");
        let first_run = RunId::new();
        let second_run = RunId::new();
        let _ = gate.acquire(key.clone(), first_run);
        let _ = gate.release(key.clone(), first_run);
        let result = gate.acquire(key.clone(), second_run);
        assert_eq!(result, AcquireResult::Acquired { key: key.clone(), run_id: second_run });
        assert_eq!(gate.key_holder(&key), Some(second_run));
    }

    #[test]
    #[should_panic(expected = "ConcurrencyKey must not be empty")]
    fn empty_concurrency_key_is_rejected() {
        let _ = ConcurrencyKey::new("");
    }

    #[test]
    fn concurrency_key_conversions_agree() {
        let key = ConcurrencyKey::new("my-key");
        assert_eq!(ConcurrencyKey::from("my-key"), key);
        assert_eq!(ConcurrencyKey::from(String::from("my-key")), key);
        assert_eq!(key.as_str(), "my-key");
        assert_eq!(key.to_string(), "my-key");
    }
}