use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use tokio::sync::Notify;
/// Scan bookkeeping for one `(tenant_id, collection)` entry of the registry.
///
/// The derived `Default` yields zero open scans and `draining == false`.
#[derive(Debug, Default)]
pub(super) struct CollectionState {
/// Number of scans currently in flight: `try_start_scan` adds one and
/// `release_scan` subtracts one (saturating at zero).
pub(super) open_scans: usize,
/// While `true`, `try_start_scan` refuses new scans with
/// `ScanStartError::Draining`. Nothing in this part of the file sets the
/// flag; presumably the drain logic in the parent module does — confirm there.
pub(super) draining: bool,
}
/// Reason a scan could not be started.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScanStartError {
    /// The target collection is draining and is not accepting new scans.
    Draining,
}

impl std::fmt::Display for ScanStartError {
    /// Writes the human-readable message for the error.
    ///
    /// The text goes straight to the formatter via `write_str`, so width and
    /// alignment flags in the format spec have no effect.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            ScanStartError::Draining => "collection is draining",
        };
        f.write_str(message)
    }
}

/// No underlying cause is exposed; the trait's default methods are used as-is.
impl std::error::Error for ScanStartError {}
/// Registry of in-flight scans, tracked per `(tenant_id, collection)`.
///
/// `new` returns the registry behind an `Arc`; every `ScanGuard` handed out by
/// `try_start_scan` keeps a clone of that `Arc`.
#[derive(Debug, Default)]
pub struct CollectionQuiesce {
/// Per-collection state, guarded by a blocking `std::sync::Mutex`.
pub(super) inner: Mutex<Inner>,
/// Signalled with `notify_waiters` at the end of every `release_scan`.
/// The waiting side is not in this part of the file — presumably a drain
/// routine in the parent module; confirm there.
pub(super) notify: Notify,
}
/// Data protected by the `inner` mutex of `CollectionQuiesce`.
#[derive(Debug, Default)]
pub(super) struct Inner {
/// Scan state keyed by `(tenant_id, collection name)`. Entries are created
/// on demand by `try_start_scan`; nothing in this part of the file removes them.
pub(super) states: HashMap<(u64, String), CollectionState>,
}
impl CollectionQuiesce {
    /// Creates an empty registry, already wrapped in an [`Arc`] because
    /// [`CollectionQuiesce::try_start_scan`] needs to clone the handle into
    /// each guard it returns.
    pub fn new() -> Arc<Self> {
        let registry = Self::default();
        Arc::new(registry)
    }

    /// Registers a new scan against `(tenant_id, collection)`.
    ///
    /// An entry for the pair is created if none exists yet. On success the
    /// open-scan counter is incremented and the returned [`ScanGuard`] undoes
    /// that increment when it is released or dropped.
    ///
    /// # Errors
    ///
    /// Returns [`ScanStartError::Draining`] when the entry is marked as
    /// draining; the counter is left untouched in that case.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn try_start_scan(
        self: &Arc<Self>,
        tenant_id: u64,
        collection: &str,
    ) -> Result<ScanGuard, ScanStartError> {
        let key = (tenant_id, collection.to_string());
        let mut locked = self.inner.lock().expect("CollectionQuiesce mutex poisoned");
        let state = locked.states.entry(key).or_default();

        if state.draining {
            Err(ScanStartError::Draining)
        } else {
            state.open_scans += 1;
            Ok(ScanGuard {
                registry: Arc::clone(self),
                tenant_id,
                collection: collection.to_string(),
                released: false,
            })
        }
    }

    /// Returns how many scans are currently open on `(tenant_id, collection)`,
    /// or `0` when the pair has no entry.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn open_scans(&self, tenant_id: u64, collection: &str) -> usize {
        let key = (tenant_id, collection.to_string());
        let locked = self.inner.lock().expect("CollectionQuiesce mutex poisoned");
        match locked.states.get(&key) {
            Some(state) => state.open_scans,
            None => 0,
        }
    }

    /// Reports whether `(tenant_id, collection)` is marked as draining.
    /// A pair without an entry is not draining.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn is_draining(&self, tenant_id: u64, collection: &str) -> bool {
        let key = (tenant_id, collection.to_string());
        let locked = self.inner.lock().expect("CollectionQuiesce mutex poisoned");
        matches!(locked.states.get(&key), Some(state) if state.draining)
    }

    /// Gives back one open-scan slot for `(tenant_id, collection)`.
    ///
    /// If the pair has an entry its counter is decremented, saturating at
    /// zero; debug builds additionally assert that the counter was non-zero.
    /// Waiters on `notify` are woken on every call, whether or not an entry
    /// was found.
    ///
    /// This is what [`ScanGuard`] calls from both `release` and `Drop`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub(super) fn release_scan(&self, tenant_id: u64, collection: &str) {
        let key = (tenant_id, collection.to_string());
        {
            // Scoped so the lock is released before waiters are woken.
            let mut locked = self.inner.lock().expect("CollectionQuiesce mutex poisoned");
            if let Some(state) = locked.states.get_mut(&key) {
                debug_assert!(state.open_scans > 0, "release without matching acquire");
                state.open_scans = state.open_scans.saturating_sub(1);
            }
        }
        self.notify.notify_waiters();
    }
}
/// Handle for one open scan, returned by `CollectionQuiesce::try_start_scan`.
///
/// The scan is released once: either through `release`, or — if that was
/// never called — when the guard is dropped.
#[must_use = "ScanGuard must be held for the lifetime of the scan"]
pub struct ScanGuard {
/// Registry whose counter this guard decrements on release.
registry: Arc<CollectionQuiesce>,
/// Tenant half of the registry key.
tenant_id: u64,
/// Collection half of the registry key.
collection: String,
/// Set by `release` so that `Drop` does not release a second time.
released: bool,
}
impl ScanGuard {
    /// Ends the scan immediately instead of waiting for the guard to go out
    /// of scope.
    ///
    /// The guard is consumed. The `released` flag is flipped *before* the
    /// registry is touched, so the `Drop` impl — which still runs when `self`
    /// falls out of scope at the end of this method — sees the flag and does
    /// not release a second time.
    pub fn release(mut self) {
        let Self {
            registry,
            tenant_id,
            collection,
            released,
        } = &mut self;

        *released = true;
        registry.release_scan(*tenant_id, collection.as_str());
    }
}
impl Drop for ScanGuard {
    /// Releases the scan slot unless `release` has already done so.
    fn drop(&mut self) {
        if self.released {
            // Already handed back explicitly; nothing left to do.
            return;
        }
        self.registry
            .release_scan(self.tenant_id, &self.collection);
    }
}
impl std::fmt::Debug for ScanGuard {
    /// Formats the guard as a struct showing `tenant_id`, `collection` and
    /// `released`; the `registry` handle is left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ScanGuard");
        out.field("tenant_id", &self.tenant_id);
        out.field("collection", &self.collection);
        out.field("released", &self.released);
        out.finish()
    }
}
// Unit tests for the scan counter and guard lifecycle.
// NOTE(review): `begin_drain` is called below but is not defined in this part
// of the file — presumably it lives in the parent module; confirm.
#[cfg(test)]
mod tests {
use super::*;
// A guard bumps the counter while alive and restores it when it leaves scope.
#[test]
fn guard_increments_and_decrements() {
let q = CollectionQuiesce::new();
assert_eq!(q.open_scans(1, "c"), 0);
{
let _g = q.try_start_scan(1, "c").unwrap();
assert_eq!(q.open_scans(1, "c"), 1);
}
assert_eq!(q.open_scans(1, "c"), 0);
}
// Several guards on the same key are counted independently and may be
// dropped in any order.
#[test]
fn multiple_concurrent_guards() {
let q = CollectionQuiesce::new();
let g1 = q.try_start_scan(1, "c").unwrap();
let g2 = q.try_start_scan(1, "c").unwrap();
let g3 = q.try_start_scan(1, "c").unwrap();
assert_eq!(q.open_scans(1, "c"), 3);
drop(g2);
assert_eq!(q.open_scans(1, "c"), 2);
drop(g1);
drop(g3);
assert_eq!(q.open_scans(1, "c"), 0);
}
// Once a drain has begun, new scans on that key fail with `Draining`.
#[test]
fn drain_rejects_new_scans() {
let q = CollectionQuiesce::new();
q.begin_drain(1, "c");
let err = q.try_start_scan(1, "c").unwrap_err();
assert_eq!(err, ScanStartError::Draining);
assert!(q.is_draining(1, "c"));
}
// Draining is scoped to the exact (tenant, collection) pair: a different
// collection or a different tenant is unaffected.
#[test]
fn drain_does_not_affect_other_collections() {
let q = CollectionQuiesce::new();
q.begin_drain(1, "c");
assert!(q.try_start_scan(1, "other").is_ok());
assert!(q.try_start_scan(2, "c").is_ok());
}
// `release()` decrements exactly once; the drop that follows must not
// decrement again.
#[test]
fn explicit_release_matches_drop() {
let q = CollectionQuiesce::new();
let g = q.try_start_scan(1, "c").unwrap();
assert_eq!(q.open_scans(1, "c"), 1);
g.release();
assert_eq!(q.open_scans(1, "c"), 0);
}
}