use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use uuid::Uuid;
/// Exclusive claim on a single session, produced by [`SessionGate::acquire`].
///
/// The session stays claimed for as long as this value is alive; dropping it
/// releases the underlying per-session mutex so the session can be acquired
/// again.
#[must_use = "dropping a SessionGuard immediately releases the session"]
pub struct SessionGuard {
    // Held only for its Drop side effect: unlocking the per-session mutex.
    _inner: tokio::sync::OwnedMutexGuard<()>,
}
impl SessionGuard {
fn new(inner: tokio::sync::OwnedMutexGuard<()>) -> Self {
Self { _inner: inner }
}
}
impl std::fmt::Debug for SessionGuard {
    /// Renders as `SessionGuard { .. }` without exposing the inner guard.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("SessionGuard");
        builder.finish_non_exhaustive()
    }
}
/// Hands out at most one live [`SessionGuard`] per session id.
///
/// Cloning the gate clones the inner `Arc`, so every clone shares the same
/// session table.
#[derive(Debug, Clone, Default)]
pub struct SessionGate {
    // Session id -> mutex whose owned guard marks that session as claimed.
    locks: Arc<RwLock<HashMap<Uuid, Arc<Mutex<()>>>>>,
}
impl SessionGate {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub async fn acquire(&self, session_id: Uuid) -> Result<SessionGuard, SessionBusy> {
let mutex = {
let read = self.locks.read().await;
read.get(&session_id).cloned()
};
let mutex = if let Some(m) = mutex {
m
} else {
let mut write = self.locks.write().await;
write
.entry(session_id)
.or_insert_with(|| Arc::new(Mutex::new(())))
.clone()
};
let guard = mutex
.try_lock_owned()
.map_err(|_| SessionBusy { session_id })?;
Ok(SessionGuard::new(guard))
}
}
/// Error from [`SessionGate::acquire`] when the session is already claimed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionBusy {
    /// Id of the session that could not be acquired.
    pub session_id: Uuid,
}
impl std::fmt::Display for SessionBusy {
    /// Human-readable message naming the busy session.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(out, "session {} is busy — another run is active", self.session_id)
    }
}
impl std::error::Error for SessionBusy {}
#[cfg(test)]
// Tests intentionally panic on unexpected results.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn acquire_and_release() {
        let gate = SessionGate::new();
        let sid = Uuid::new_v4();
        let guard = gate.acquire(sid).await.expect("should acquire");
        drop(guard);
        let _guard2 = gate
            .acquire(sid)
            .await
            .expect("should re-acquire after drop");
    }

    #[tokio::test]
    async fn concurrent_acquire_fails() {
        let gate = SessionGate::new();
        let sid = Uuid::new_v4();
        let guard = gate.acquire(sid).await.expect("first acquire");
        let err = gate
            .acquire(sid)
            .await
            .expect_err("second acquire should fail");
        assert_eq!(err, SessionBusy { session_id: sid });
        drop(guard);
        let _guard2 = gate
            .acquire(sid)
            .await
            .expect("should succeed after release");
    }

    #[tokio::test]
    async fn independent_sessions_no_contention() {
        let gate = SessionGate::new();
        let sid1 = Uuid::new_v4();
        let sid2 = Uuid::new_v4();
        let guard1 = gate.acquire(sid1).await.expect("acquire sid1");
        let guard2 = gate.acquire(sid2).await.expect("acquire sid2");
        drop(guard1);
        drop(guard2);
    }

    #[tokio::test]
    async fn sequential_runs_serialize() {
        let gate = SessionGate::new();
        let sid = Uuid::new_v4();
        // Repeated claim/release cycles on one session must all succeed.
        for attempt in 0..3 {
            let guard = gate
                .acquire(sid)
                .await
                .unwrap_or_else(|e| panic!("attempt {attempt} failed: {e}"));
            drop(guard);
        }
    }

    #[tokio::test]
    async fn clones_share_locks() {
        let gate = SessionGate::new();
        let other = gate.clone();
        let sid = Uuid::new_v4();
        let _guard = gate.acquire(sid).await.expect("acquire via original");
        assert!(
            other.acquire(sid).await.is_err(),
            "a clone must observe sessions held through the original"
        );
    }

    #[test]
    fn busy_display_mentions_session() {
        let sid = Uuid::new_v4();
        let msg = SessionBusy { session_id: sid }.to_string();
        assert!(msg.contains(&sid.to_string()), "message was: {msg}");
    }
}