use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, OwnedMutexGuard, RwLock};
use mempill_types::AgentId;
#[derive(Clone)]
pub struct AgentWriteLockMap {
locks: Arc<RwLock<HashMap<String, Arc<Mutex<()>>>>>,
}
impl AgentWriteLockMap {
pub fn new() -> Self {
Self {
locks: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn acquire(&self, agent_id: &AgentId) -> OwnedMutexGuard<()> {
let lock = {
let read = self.locks.read().await;
read.get(&agent_id.0).cloned()
};
let lock = match lock {
Some(l) => l,
None => {
let mut write = self.locks.write().await;
write
.entry(agent_id.0.clone())
.or_insert_with(|| Arc::new(Mutex::new(())))
.clone()
}
};
lock.lock_owned().await
}
}
impl Default for AgentWriteLockMap {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::sync::Barrier;
    use tokio::time::{timeout, Duration};

    /// A second `acquire` for the same agent must wait until the first guard is dropped.
    #[tokio::test]
    async fn same_agent_id_locks_are_mutually_exclusive() {
        let map = AgentWriteLockMap::new();
        let agent_a = AgentId("agent-A".into());
        let counter = Arc::new(AtomicUsize::new(0));

        let guard1 = map.acquire(&agent_a).await;
        counter.fetch_add(1, Ordering::SeqCst);

        let map2 = map.clone();
        let agent_a2 = agent_a.clone();
        let counter2 = Arc::clone(&counter);
        let handle = tokio::spawn(async move {
            let _guard2 = map2.acquire(&agent_a2).await;
            counter2.fetch_add(10, Ordering::SeqCst);
        });

        // Give the spawned task a chance to reach (and block in) `acquire`.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "second acquire must be blocked while first guard is held"
        );

        drop(guard1);
        timeout(Duration::from_millis(200), handle)
            .await
            .expect("task timed out")
            .expect("task panicked");
        assert_eq!(
            counter.load(Ordering::SeqCst),
            11,
            "after releasing first guard, second acquire must complete"
        );
    }

    /// Holding agent-A's lock must not delay acquiring agent-B's lock.
    #[tokio::test]
    async fn different_agent_ids_proceed_independently() {
        let map = AgentWriteLockMap::new();
        let agent_a = AgentId("agent-A".into());
        let agent_b = AgentId("agent-B".into());

        let _guard_a = map.acquire(&agent_a).await;
        let acquire_b = timeout(Duration::from_millis(100), map.acquire(&agent_b)).await;
        assert!(
            acquire_b.is_ok(),
            "acquiring lock for agent-B must not block even when agent-A lock is held"
        );
    }

    /// Clones share the same lock table, so a clone observes locks taken through the original.
    #[tokio::test]
    async fn lock_map_is_cloneable_and_shares_state() {
        let map = AgentWriteLockMap::new();
        let map_clone = map.clone();
        let agent = AgentId("agent-clone-test".into());

        let _guard = map.acquire(&agent).await;
        let acquire_via_clone =
            timeout(Duration::from_millis(30), map_clone.acquire(&agent)).await;
        assert!(
            acquire_via_clone.is_err(),
            "cloned map must share lock state — second acquire via clone must block"
        );
    }

    /// Dropping a guard releases the agent's lock for the next caller.
    #[tokio::test]
    async fn guard_drops_release_the_lock() {
        let map = AgentWriteLockMap::new();
        let agent = AgentId("agent-drop-test".into());
        {
            let _guard = map.acquire(&agent).await;
        }
        let reacquire = timeout(Duration::from_millis(50), map.acquire(&agent)).await;
        assert!(reacquire.is_ok(), "lock must be released after guard drop");
    }

    /// Locks for distinct agents can all be held at the same moment.
    #[tokio::test]
    async fn multiple_agents_independently_concurrent() {
        const AGENTS: usize = 8;
        let map = AgentWriteLockMap::new();
        // Each task waits on the barrier while holding its own agent lock. The barrier only
        // opens once all AGENTS tasks hold their locks simultaneously; if per-agent locks were
        // serialized, it would never open and the timeout below would fire.
        let barrier = Arc::new(Barrier::new(AGENTS));

        let handles: Vec<_> = (0..AGENTS)
            .map(|i| {
                let map_i = map.clone();
                let barrier_i = Arc::clone(&barrier);
                let agent_i = AgentId(format!("agent-{i}"));
                tokio::spawn(async move {
                    let _g = map_i.acquire(&agent_i).await;
                    barrier_i.wait().await;
                })
            })
            .collect();

        let results = timeout(
            Duration::from_millis(500),
            futures::future::join_all(handles),
        )
        .await
        .expect("agent locks were not held concurrently");
        for (i, r) in results.into_iter().enumerate() {
            r.unwrap_or_else(|e| panic!("agent-{i} lock task panicked: {e}"));
        }
    }
}