use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::{Mutex as AsyncMutex, OwnedMutexGuard};
/// Identifies one write queue: a table name plus an optional branch name.
///
/// Two keys that differ only in the branch component are distinct queues, so
/// writes to the same table on different branches do not serialize.
/// NOTE(review): `None` appears to mean the table's default/main branch
/// (the tests name such a key `main_k`) — confirm against callers.
pub(crate) type TableQueueKey = (String, Option<String>);
/// Hands out per-key async locks so that writes sharing a [`TableQueueKey`]
/// run one at a time, while writes to different keys proceed concurrently.
#[derive(Default)]
pub(crate) struct WriteQueueManager {
/// Lazily populated map from queue key to that key's async mutex.
///
/// The outer lock is a blocking `std::sync::Mutex`; it is only taken for
/// short, synchronous map lookups/updates and must never be held across an
/// `.await`. The per-key `AsyncMutex` is what callers actually wait on.
queues: Mutex<HashMap<TableQueueKey, Arc<AsyncMutex<()>>>>,
}
impl WriteQueueManager {
    /// Creates an empty manager; per-key queues are allocated on first use.
    pub(crate) fn new() -> Self {
        Self::default()
    }

    /// Returns the shared async mutex for `key`, creating it on first use.
    ///
    /// Whenever a new key has to be inserted, slots that nobody holds or waits
    /// on are pruned first, so the map does not grow without bound as
    /// tables/branches come and go.
    ///
    /// # Panics
    ///
    /// Panics if the map's std mutex was poisoned by a panic in another thread.
    fn slot(&self, key: &TableQueueKey) -> Arc<AsyncMutex<()>> {
        let mut map = self.queues.lock().expect("write queue map poisoned");
        // Fast path: `get` + clone avoids cloning the key when the slot exists
        // (the entry API would need an owned key on every call).
        if let Some(existing) = map.get(key) {
            return Arc::clone(existing);
        }
        // A strong count of 1 means the map holds the only reference: every
        // `OwnedMutexGuard` and every pending `lock_owned` future owns its own
        // `Arc`, so no guard or waiter refers to this slot. New references are
        // only handed out here, under the map lock, so an idle slot cannot be
        // revived concurrently; dropping it is safe and the next `acquire` for
        // that key simply creates a fresh mutex. After pruning, the map holds
        // only active slots, which keeps this scan amortized-cheap.
        map.retain(|_, slot| Arc::strong_count(slot) > 1);
        let fresh = Arc::new(AsyncMutex::new(()));
        map.insert(key.clone(), Arc::clone(&fresh));
        fresh
    }

    /// Waits for exclusive write access to `key` and returns the guard.
    ///
    /// The key stays locked until the returned guard is dropped; discarding
    /// the guard immediately releases it.
    pub(crate) async fn acquire(&self, key: &TableQueueKey) -> OwnedMutexGuard<()> {
        // `slot` releases the std map lock before returning, so nothing
        // blocking is held across this `.await`.
        self.slot(key).lock_owned().await
    }

    /// Acquires every distinct key in `keys`, returning one guard per distinct
    /// key, in ascending key order. An empty slice yields an empty `Vec`.
    ///
    /// Duplicates are removed first because tokio's mutex is not reentrant:
    /// locking the same key twice from one caller would wait forever. Keys are
    /// then locked in ascending order, so concurrent `acquire_many` calls over
    /// overlapping key sets always contend on the same first key rather than
    /// deadlocking on each other. If the returned future is dropped part-way,
    /// the guards collected so far are dropped (released) with it.
    ///
    /// NOTE(review): the ordering guarantee only covers guards taken through
    /// this method; a caller that already holds a guard from `acquire` and
    /// then calls `acquire_many` can still deadlock against another caller.
    pub(crate) async fn acquire_many(
        &self,
        keys: &[TableQueueKey],
    ) -> Vec<OwnedMutexGuard<()>> {
        let mut sorted: Vec<TableQueueKey> = keys.to_vec();
        // Stability is irrelevant: equal keys are merged by `dedup` right after.
        sorted.sort_unstable();
        sorted.dedup();
        let mut guards = Vec::with_capacity(sorted.len());
        for key in &sorted {
            guards.push(self.acquire(key).await);
        }
        guards
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};
    use tokio::time::{sleep, timeout};

    /// Builds a `(table, branch)` queue key from borrowed parts.
    fn key(table: &str, branch: Option<&str>) -> TableQueueKey {
        (table.to_string(), branch.map(str::to_string))
    }

    #[tokio::test]
    async fn acquire_many_empty_returns_empty() {
        let qm = WriteQueueManager::new();
        let guards = qm.acquire_many(&[]).await;
        assert!(guards.is_empty());
    }

    #[tokio::test]
    async fn acquire_many_dedupes_repeated_keys() {
        let qm = WriteQueueManager::new();
        let k = key("t1", None);
        // tokio's mutex is not reentrant: without dedup, the second lock of `k`
        // would wait on the first forever and this would time out.
        let guards = timeout(
            Duration::from_secs(2),
            qm.acquire_many(&[k.clone(), k.clone(), k]),
        )
        .await
        .expect("acquire_many with duplicates deadlocked");
        assert_eq!(guards.len(), 1);
    }

    #[tokio::test]
    async fn acquire_many_sorts_keys_deterministically() {
        let qm = Arc::new(WriteQueueManager::new());
        let a = key("a", None);
        let z = key("z", None);
        // Hold the lex-LAST key so `acquire_many` parks on it. If keys are
        // locked in sorted order, it must already own `a` while parked; in
        // caller order (`z` first) it would park without ever touching `a`.
        // (Holding `a` instead would block in either order, proving nothing.)
        let held_z = qm.acquire(&z).await;
        let qm2 = Arc::clone(&qm);
        let keys = vec![z.clone(), a.clone()];
        let task = tokio::spawn(async move { qm2.acquire_many(&keys).await });
        // Give the spawned task a chance to run until it parks on `z`.
        sleep(Duration::from_millis(50)).await;
        let probe = timeout(Duration::from_millis(200), qm.acquire(&a)).await;
        assert!(
            probe.is_err(),
            "acquire_many should lock `a` (the lex-first key) before waiting on `z`"
        );
        drop(held_z);
        let guards = timeout(Duration::from_secs(2), task)
            .await
            .expect("acquire_many should finish once `z` is released")
            .expect("acquire_many task panicked");
        assert_eq!(guards.len(), 2);
    }

    #[tokio::test]
    async fn same_key_acquire_serializes() {
        let qm = Arc::new(WriteQueueManager::new());
        let k = key("t1", None);
        let first = qm.acquire(&k).await;
        let qm2 = Arc::clone(&qm);
        let k2 = k.clone();
        let blocked = timeout(Duration::from_millis(200), async move {
            qm2.acquire(&k2).await
        })
        .await;
        assert!(blocked.is_err(), "second acquire on same key must block");
        drop(first);
        let _second = timeout(Duration::from_secs(2), qm.acquire(&k))
            .await
            .expect("second acquire after release should not block");
    }

    #[tokio::test]
    async fn disjoint_keys_acquire_concurrently() {
        let qm = WriteQueueManager::new();
        let a = key("a", None);
        let b = key("b", None);
        let _held_a = qm.acquire(&a).await;
        let start = Instant::now();
        let _held_b = timeout(Duration::from_secs(2), qm.acquire(&b))
            .await
            .expect("disjoint key acquire must not block on unrelated held key");
        let elapsed = start.elapsed();
        assert!(
            elapsed < Duration::from_millis(500),
            "disjoint acquire took {:?}, should be near-instant",
            elapsed
        );
    }

    #[tokio::test]
    async fn disjoint_branches_on_same_table_do_not_serialize() {
        let qm = Arc::new(WriteQueueManager::new());
        let main_k = key("t1", None);
        let feature_k = key("t1", Some("feature"));
        let _held_main = qm.acquire(&main_k).await;
        let _held_feature = timeout(Duration::from_secs(2), qm.acquire(&feature_k))
            .await
            .expect("same-table-different-branch should not serialize");
    }
}