use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tracing::debug;
use nodedb_types::{DatabaseId, TenantId};
/// Reason a connection attempt was refused by [`AdmissionRegistry`].
///
/// Each variant carries the cap that was configured when the attempt was rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AdmissionError {
    /// Every permit of the database-wide cap is currently held.
    DatabaseCapExhausted {
        /// Database whose cap was hit.
        db: DatabaseId,
        /// Configured maximum number of concurrent admissions for `db`.
        limit: u32,
    },
    /// Every permit of this tenant's cap within the database is currently held.
    TenantCapExhausted {
        /// Database the tenant belongs to.
        db: DatabaseId,
        /// Tenant whose cap was hit.
        tenant: TenantId,
        /// Configured maximum number of concurrent admissions for `(db, tenant)`.
        limit: u32,
    },
}
impl std::fmt::Display for AdmissionError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::DatabaseCapExhausted { db, limit } => {
write!(
f,
"database {db:?} has reached its maximum connection limit ({limit})"
)
}
Self::TenantCapExhausted { db, tenant, limit } => {
write!(
f,
"tenant {tenant:?} in database {db:?} has reached its maximum \
connection limit ({limit})"
)
}
}
}
}
impl std::error::Error for AdmissionError {}
/// Database-wide cap: one semaphore permit per admission, plus the configured size.
struct DbEntry {
    /// Configured cap, kept next to the semaphore so rejections can report it.
    limit: u32,
    /// Held in an `Arc` so permits can be taken as owned values via `try_acquire_owned`.
    semaphore: Arc<Semaphore>,
}
/// Per-tenant cap within a database: one semaphore permit per admission, plus its size.
struct TenantEntry {
    /// Configured cap, kept next to the semaphore so rejections can report it.
    limit: u32,
    /// Held in an `Arc` so permits can be taken as owned values via `try_acquire_owned`.
    semaphore: Arc<Semaphore>,
}
/// Registry of concurrency caps, keyed by database and by `(database, tenant)`.
///
/// A key with no entry is uncapped. Each map sits behind its own `RwLock`, so admission
/// checks (readers) only contend with limit changes (writers) on the same map.
pub struct AdmissionRegistry {
    /// Per-tenant caps, keyed by the owning database and the tenant.
    tenant_semaphores: RwLock<HashMap<(DatabaseId, TenantId), TenantEntry>>,
    /// Database-wide caps.
    db_semaphores: RwLock<HashMap<DatabaseId, DbEntry>>,
}
impl AdmissionRegistry {
pub fn new() -> Self {
Self {
db_semaphores: RwLock::new(HashMap::new()),
tenant_semaphores: RwLock::new(HashMap::new()),
}
}
pub fn set_database_limit(&self, db: DatabaseId, limit: u32) {
let mut map = self
.db_semaphores
.write()
.unwrap_or_else(|p| p.into_inner());
if limit == 0 {
map.remove(&db);
} else {
map.insert(
db,
DbEntry {
semaphore: Arc::new(Semaphore::new(limit as usize)),
limit,
},
);
}
}
pub fn set_tenant_limit(&self, db: DatabaseId, tenant: TenantId, limit: u32) {
let mut map = self
.tenant_semaphores
.write()
.unwrap_or_else(|p| p.into_inner());
if limit == 0 {
map.remove(&(db, tenant));
} else {
map.insert(
(db, tenant),
TenantEntry {
semaphore: Arc::new(Semaphore::new(limit as usize)),
limit,
},
);
}
}
pub fn try_acquire_database(
&self,
db: DatabaseId,
) -> Result<Option<OwnedSemaphorePermit>, AdmissionError> {
let map = self.db_semaphores.read().unwrap_or_else(|p| p.into_inner());
let Some(entry) = map.get(&db) else {
return Ok(None); };
match entry.semaphore.clone().try_acquire_owned() {
Ok(permit) => {
debug!(db = ?db, "database admission permit acquired");
Ok(Some(permit))
}
Err(TryAcquireError::NoPermits) => Err(AdmissionError::DatabaseCapExhausted {
db,
limit: entry.limit,
}),
Err(TryAcquireError::Closed) => {
Ok(None)
}
}
}
pub fn try_acquire_tenant(
&self,
db: DatabaseId,
tenant: TenantId,
) -> Result<Option<OwnedSemaphorePermit>, AdmissionError> {
let map = self
.tenant_semaphores
.read()
.unwrap_or_else(|p| p.into_inner());
let Some(entry) = map.get(&(db, tenant)) else {
return Ok(None); };
match entry.semaphore.clone().try_acquire_owned() {
Ok(permit) => {
debug!(db = ?db, tenant = ?tenant, "tenant admission permit acquired");
Ok(Some(permit))
}
Err(TryAcquireError::NoPermits) => Err(AdmissionError::TenantCapExhausted {
db,
tenant,
limit: entry.limit,
}),
Err(TryAcquireError::Closed) => Ok(None),
}
}
}
impl Default for AdmissionRegistry {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for AdmissionRegistry {
    /// Summarises the registry as the number of configured database and tenant caps.
    ///
    /// A poisoned lock is recovered, as the registry's other accessors do, so the counts
    /// reflect the actual map contents instead of silently reading as `0`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let db_count = self
            .db_semaphores
            .read()
            .unwrap_or_else(|p| p.into_inner())
            .len();
        let tenant_count = self
            .tenant_semaphores
            .read()
            .unwrap_or_else(|p| p.into_inner())
            .len();
        f.debug_struct("AdmissionRegistry")
            .field("db_entries", &db_count)
            .field("tenant_entries", &tenant_count)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use nodedb_types::{DatabaseId, TenantId};

    use super::{AdmissionError, AdmissionRegistry};

    /// Database every test runs against.
    fn db() -> DatabaseId {
        DatabaseId::DEFAULT
    }

    fn tenant(n: u64) -> TenantId {
        TenantId::new(n)
    }

    #[test]
    fn no_database_cap_allows_unlimited() {
        let reg = AdmissionRegistry::new();
        let r = reg.try_acquire_database(db());
        assert!(r.unwrap().is_none());
    }

    #[test]
    fn database_cap_allows_up_to_limit() {
        let reg = AdmissionRegistry::new();
        reg.set_database_limit(db(), 2);
        let p1 = reg.try_acquire_database(db()).unwrap();
        let p2 = reg.try_acquire_database(db()).unwrap();
        assert!(p1.is_some());
        assert!(p2.is_some());
        let err = reg.try_acquire_database(db()).unwrap_err();
        assert!(matches!(
            err,
            AdmissionError::DatabaseCapExhausted { limit: 2, .. }
        ));
        drop(p1);
        let p3 = reg.try_acquire_database(db()).unwrap();
        assert!(p3.is_some());
    }

    #[test]
    fn tenant_cap_isolates_tenants() {
        let reg = AdmissionRegistry::new();
        reg.set_database_limit(db(), 100);
        reg.set_tenant_limit(db(), tenant(1), 1);
        let t1_permit = reg.try_acquire_tenant(db(), tenant(1)).unwrap();
        assert!(t1_permit.is_some());
        let err = reg.try_acquire_tenant(db(), tenant(1)).unwrap_err();
        assert!(matches!(
            err,
            AdmissionError::TenantCapExhausted { limit: 1, .. }
        ));
        // Tenant 2 has no cap of its own, so it is admitted without a permit.
        let t2_permit = reg.try_acquire_tenant(db(), tenant(2)).unwrap();
        assert!(t2_permit.is_none());
    }

    #[test]
    fn set_limit_zero_removes_cap() {
        let reg = AdmissionRegistry::new();
        reg.set_database_limit(db(), 1);
        let _p = reg.try_acquire_database(db()).unwrap().unwrap();
        reg.set_database_limit(db(), 0);
        let r = reg.try_acquire_database(db()).unwrap();
        assert!(r.is_none());
    }

    #[test]
    fn set_tenant_limit_zero_removes_cap() {
        let reg = AdmissionRegistry::new();
        reg.set_tenant_limit(db(), tenant(1), 1);
        let _p = reg.try_acquire_tenant(db(), tenant(1)).unwrap().unwrap();
        reg.set_tenant_limit(db(), tenant(1), 0);
        let r = reg.try_acquire_tenant(db(), tenant(1)).unwrap();
        assert!(r.is_none());
    }

    #[test]
    fn errors_render_their_limit() {
        let db_err = AdmissionError::DatabaseCapExhausted { db: db(), limit: 3 };
        let msg = db_err.to_string();
        assert!(msg.starts_with("database "));
        assert!(msg.ends_with("has reached its maximum connection limit (3)"));

        let tenant_err = AdmissionError::TenantCapExhausted {
            db: db(),
            tenant: tenant(7),
            limit: 1,
        };
        let msg = tenant_err.to_string();
        assert!(msg.starts_with("tenant "));
        assert!(msg.ends_with("has reached its maximum connection limit (1)"));
    }

    #[test]
    fn debug_reports_entry_counts() {
        let reg = AdmissionRegistry::new();
        reg.set_database_limit(db(), 4);
        assert_eq!(
            format!("{reg:?}"),
            "AdmissionRegistry { db_entries: 1, tenant_entries: 0 }"
        );
    }
}