use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use dashmap::DashMap;
use tokio::sync::{OwnedSemaphorePermit, Semaphore, TryAcquireError};
/// Default number of concurrently admitted requests allowed per actor.
pub const DEFAULT_PER_ACTOR_INFLIGHT_MAX: u32 = 16;
/// Default per-actor byte budget: 4 GiB (4 × 2^30 bytes).
pub const DEFAULT_PER_ACTOR_BYTES_MAX: u64 = 4 << 30;
/// Why an admission attempt was refused.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RejectReason {
    /// The actor already holds all `cap` in-flight slots.
    InFlightCountExceeded { cap: u32 },
    /// Reserving the request would bring the actor's byte total to `attempted`,
    /// which is above `cap`.
    ByteBudgetExceeded { cap: u64, attempted: u64 },
}

impl std::fmt::Display for RejectReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::InFlightCountExceeded { cap } => {
                write!(f, "actor in-flight count cap {cap} exceeded")
            }
            Self::ByteBudgetExceeded { cap, attempted } => {
                write!(
                    f,
                    "actor byte budget exceeded: would use {attempted} bytes against cap {cap}"
                )
            }
        }
    }
}
/// Admission bookkeeping for one actor, shared by every guard issued for it.
#[derive(Debug)]
pub(crate) struct ActorState {
    /// Semaphore created with `inflight_cap` permits; one permit per admitted request.
    in_flight_sem: Arc<Semaphore>,
    /// Bytes currently reserved against `byte_cap`.
    bytes: AtomicU64,
    /// Maximum value `bytes` may reach for an admission to succeed.
    byte_cap: u64,
    /// Permit count the semaphore was created with, kept for rejection reporting.
    inflight_cap: u32,
}

impl ActorState {
    /// Builds state with every permit available and no bytes reserved.
    fn new(inflight_cap: u32, byte_cap: u64) -> Self {
        let permits = inflight_cap as usize;
        let in_flight_sem = Arc::new(Semaphore::new(permits));
        let bytes = AtomicU64::new(0);
        Self {
            in_flight_sem,
            bytes,
            byte_cap,
            inflight_cap,
        }
    }
}
/// Admission controller that bounds, per actor, both the number of in-flight
/// requests and the total estimated bytes those requests may reserve.
///
/// Actor state is created lazily on first use with the caps configured here.
/// [`WorkloadController::try_admit`] never waits: it either hands out an
/// [`AdmissionGuard`] or returns a [`RejectReason`].
pub struct WorkloadController {
    /// Per-actor state keyed by actor id.
    // NOTE(review): nothing in this module removes entries, so the map grows with the
    // number of distinct actor ids ever passed to `try_admit` — confirm that set is bounded.
    per_actor: DashMap<Arc<str>, Arc<ActorState>>,
    /// In-flight cap given to each newly created actor entry.
    inflight_cap: u32,
    /// Byte budget given to each newly created actor entry.
    byte_cap: u64,
}

impl WorkloadController {
    /// Creates a controller granting every actor `inflight_cap` concurrent slots and
    /// a budget of `byte_cap` bytes.
    pub fn new(inflight_cap: u32, byte_cap: u64) -> Self {
        Self {
            per_actor: DashMap::new(),
            inflight_cap,
            byte_cap,
        }
    }

    /// Creates a controller from `OMNIGRAPH_PER_ACTOR_INFLIGHT_MAX` and
    /// `OMNIGRAPH_PER_ACTOR_BYTES_MAX`. [`DEFAULT_PER_ACTOR_INFLIGHT_MAX`] and
    /// [`DEFAULT_PER_ACTOR_BYTES_MAX`] are used for any variable that is unset or
    /// unparseable.
    pub fn from_env() -> Self {
        let inflight_cap = parse_env_u32(
            "OMNIGRAPH_PER_ACTOR_INFLIGHT_MAX",
            DEFAULT_PER_ACTOR_INFLIGHT_MAX,
        );
        let byte_cap = parse_env_u64("OMNIGRAPH_PER_ACTOR_BYTES_MAX", DEFAULT_PER_ACTOR_BYTES_MAX);
        Self::new(inflight_cap, byte_cap)
    }

    /// Creates a controller using [`DEFAULT_PER_ACTOR_INFLIGHT_MAX`] and
    /// [`DEFAULT_PER_ACTOR_BYTES_MAX`].
    pub fn with_defaults() -> Self {
        Self::new(DEFAULT_PER_ACTOR_INFLIGHT_MAX, DEFAULT_PER_ACTOR_BYTES_MAX)
    }

    /// Returns the shared state for `actor_id`, creating it on first use.
    fn actor_state(&self, actor_id: &Arc<str>) -> Arc<ActorState> {
        // Fast path: a plain lookup avoids cloning the key when the actor already exists.
        if let Some(existing) = self.per_actor.get(actor_id) {
            return existing.clone();
        }
        // Slow path: `or_insert_with` only inserts if the slot is still vacant, so a
        // concurrent caller that created this actor first keeps its entry.
        self.per_actor
            .entry(actor_id.clone())
            .or_insert_with(|| Arc::new(ActorState::new(self.inflight_cap, self.byte_cap)))
            .clone()
    }

    /// Attempts to admit one request for `actor_id` with an estimated size of
    /// `est_bytes`.
    ///
    /// On success, the returned [`AdmissionGuard`] holds one in-flight slot and
    /// reserves `est_bytes` of the actor's byte budget until it is dropped.
    ///
    /// # Errors
    ///
    /// * [`RejectReason::InFlightCountExceeded`] if the actor has no free in-flight slot.
    /// * [`RejectReason::ByteBudgetExceeded`] if reserving `est_bytes` would push the
    ///   actor above its byte budget or overflow `u64`. Here `attempted` is the
    ///   saturated total. The slot taken for the attempt is released when this call
    ///   returns.
    pub fn try_admit(
        &self,
        actor_id: &Arc<str>,
        est_bytes: u64,
    ) -> Result<AdmissionGuard, RejectReason> {
        let state = self.actor_state(actor_id);
        let permit = match Arc::clone(&state.in_flight_sem).try_acquire_owned() {
            Ok(permit) => permit,
            // Nothing in this module closes the semaphore, so `Closed` is not expected;
            // it is reported like exhaustion rather than panicking.
            Err(TryAcquireError::NoPermits) | Err(TryAcquireError::Closed) => {
                return Err(RejectReason::InFlightCountExceeded {
                    cap: state.inflight_cap,
                });
            }
        };

        // Reserve bytes with a compare-and-swap loop so the counter only ever holds
        // committed reservations. An add-then-undo scheme has two problems. First, it
        // briefly inflates the counter, which spuriously rejects concurrent callers.
        // Second, it can wrap on huge estimates, so a concurrent caller is checked
        // against a too-small total and over-admitted.
        let cap = state.byte_cap;
        let reserved = state
            .bytes
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
                current.checked_add(est_bytes).filter(|&next| next <= cap)
            });
        if let Err(current) = reserved {
            // `permit` is dropped on return, freeing the slot for other callers.
            return Err(RejectReason::ByteBudgetExceeded {
                cap,
                attempted: current.saturating_add(est_bytes),
            });
        }

        Ok(AdmissionGuard {
            _permit: permit,
            actor_state: state,
            est_bytes,
        })
    }
}
/// Handle for one admitted request.
///
/// While alive, it holds one of the actor's in-flight permits and keeps
/// `est_bytes` of the actor's byte budget reserved. Dropping it gives both back.
#[derive(Debug)]
pub struct AdmissionGuard {
    /// In-flight permit; released when the guard's fields are dropped.
    _permit: OwnedSemaphorePermit,
    /// Actor whose byte counter this guard charged.
    actor_state: Arc<ActorState>,
    /// Number of bytes this guard reserved.
    est_bytes: u64,
}

impl Drop for AdmissionGuard {
    fn drop(&mut self) {
        // The byte reservation is returned here. The permit field is dropped only
        // after this method finishes.
        let Self {
            actor_state,
            est_bytes,
            ..
        } = self;
        actor_state.bytes.fetch_sub(*est_bytes, Ordering::SeqCst);
    }
}
fn parse_env_u32(name: &str, default: u32) -> u32 {
match std::env::var(name) {
Ok(v) => v.parse::<u32>().unwrap_or_else(|err| {
tracing::warn!(
env = name,
value = %v,
error = %err,
default,
"invalid env value, using default"
);
default
}),
Err(_) => default,
}
}
fn parse_env_u64(name: &str, default: u64) -> u64 {
match std::env::var(name) {
Ok(v) => v.parse::<u64>().unwrap_or_else(|err| {
tracing::warn!(
env = name,
value = %v,
error = %err,
default,
"invalid env value, using default"
);
default
}),
Err(_) => default,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // `try_admit` is synchronous, so only the concurrency test needs a runtime.

    #[test]
    fn try_admit_admits_under_cap() {
        let controller = WorkloadController::new(2, 1024);
        let actor: Arc<str> = "alice".into();
        let g1 = controller.try_admit(&actor, 100).expect("first admit");
        let _g2 = controller.try_admit(&actor, 100).expect("second admit");
        let err = controller
            .try_admit(&actor, 100)
            .expect_err("third should reject on count");
        assert!(matches!(
            err,
            RejectReason::InFlightCountExceeded { cap: 2 }
        ));
        drop(g1);
        let _g3 = controller.try_admit(&actor, 100).expect("admit after drop");
    }

    #[test]
    fn byte_budget_caps_admission() {
        let controller = WorkloadController::new(16, 1000);
        let actor: Arc<str> = "alice".into();
        let _g1 = controller.try_admit(&actor, 600).expect("first admit");
        let err = controller
            .try_admit(&actor, 600)
            .expect_err("second should reject on bytes");
        match err {
            RejectReason::ByteBudgetExceeded { cap, attempted } => {
                assert_eq!(cap, 1000);
                assert_eq!(attempted, 1200);
            }
            other => panic!("expected ByteBudgetExceeded, got {:?}", other),
        }
        let _g2 = controller.try_admit(&actor, 300).expect("smaller admit");
    }

    #[test]
    fn byte_rejection_releases_in_flight_slot() {
        let controller = WorkloadController::new(1, 100);
        let actor: Arc<str> = "alice".into();
        let err = controller
            .try_admit(&actor, 200)
            .expect_err("oversized request rejected on bytes");
        assert_eq!(
            err,
            RejectReason::ByteBudgetExceeded {
                cap: 100,
                attempted: 200
            }
        );
        // The single slot must not have been leaked by the rejected attempt.
        let _g = controller.try_admit(&actor, 50).expect("slot still free");
    }

    #[test]
    fn dropping_guard_returns_bytes() {
        let controller = WorkloadController::new(16, 1000);
        let actor: Arc<str> = "alice".into();
        let g = controller.try_admit(&actor, 600).expect("first admit");
        drop(g);
        let _g = controller.try_admit(&actor, 600).expect("bytes were released");
    }

    #[test]
    fn oversized_estimate_does_not_corrupt_budget() {
        let controller = WorkloadController::new(16, 1000);
        let actor: Arc<str> = "alice".into();
        let _g1 = controller.try_admit(&actor, 10).expect("small admit");
        let err = controller
            .try_admit(&actor, u64::MAX)
            .expect_err("u64::MAX estimate must be rejected");
        assert_eq!(
            err,
            RejectReason::ByteBudgetExceeded {
                cap: 1000,
                attempted: u64::MAX
            }
        );
        // Exactly the remaining 990 bytes fit; one more byte does not.
        let _g2 = controller.try_admit(&actor, 990).expect("fills budget exactly");
        let err = controller
            .try_admit(&actor, 1)
            .expect_err("budget is full");
        assert_eq!(
            err,
            RejectReason::ByteBudgetExceeded {
                cap: 1000,
                attempted: 1001
            }
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn actor_admission_race_does_not_exceed_cap() {
        const TASKS: usize = 32;
        let controller = Arc::new(WorkloadController::new(16, u64::MAX / 4));
        let actor: Arc<str> = "racer".into();
        // Each task keeps its guard until every task has attempted admission, so no
        // slot can be recycled mid-test. The extra participant is this test body.
        let barrier = Arc::new(tokio::sync::Barrier::new(TASKS + 1));
        let handles: Vec<_> = (0..TASKS)
            .map(|_| {
                let controller = Arc::clone(&controller);
                let actor = actor.clone();
                let barrier = Arc::clone(&barrier);
                tokio::spawn(async move {
                    let guard = controller.try_admit(&actor, 1).ok();
                    let admitted = guard.is_some();
                    barrier.wait().await;
                    drop(guard);
                    admitted
                })
            })
            .collect();
        barrier.wait().await;
        let mut accepted = 0u32;
        let mut rejected = 0u32;
        for h in handles {
            if h.await.expect("admission task panicked") {
                accepted += 1;
            } else {
                rejected += 1;
            }
        }
        assert_eq!(accepted, 16, "expected exactly 16 successful admits");
        assert_eq!(rejected, 16, "expected exactly 16 rejections");
    }

    #[test]
    fn per_actor_caps_independent() {
        let controller = WorkloadController::new(1, 1024);
        let alice: Arc<str> = "alice".into();
        let bob: Arc<str> = "bob".into();
        let _ga = controller.try_admit(&alice, 100).expect("alice ok");
        let err = controller
            .try_admit(&alice, 100)
            .expect_err("alice rejected");
        assert!(matches!(err, RejectReason::InFlightCountExceeded { .. }));
        let _gb = controller.try_admit(&bob, 100).expect("bob ok");
    }
}