use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
/// Caps the number of concurrently held slots across the whole supervisor instance.
///
/// The counter lives behind an `Arc`, so every clone of a gate shares (and enforces) the same
/// limit.
#[derive(Debug, Clone)]
pub struct SupervisorInstanceGate {
    /// Upper bound on simultaneously held slots.
    max_concurrent: u32,
    /// Number of slots currently held; shared between clones.
    active_count: Arc<AtomicU32>,
}

impl SupervisorInstanceGate {
    /// Creates a gate that admits at most `max_concurrent` holders at once.
    ///
    /// A limit of `0` rejects every acquisition.
    pub fn new(max_concurrent: u32) -> Self {
        Self {
            max_concurrent,
            active_count: Arc::new(AtomicU32::new(0)),
        }
    }

    /// Attempts to take one slot without blocking.
    ///
    /// Returns `true` if a slot was taken (the caller must later call [`release`](Self::release)),
    /// or `false` if the gate is already at capacity.
    pub fn try_acquire(&self) -> bool {
        // `fetch_update` retries the compare-exchange on contention; returning `None` from the
        // closure aborts without touching the counter.
        self.active_count
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
                (current < self.max_concurrent).then(|| current + 1)
            })
            .is_ok()
    }

    /// Returns one previously acquired slot.
    ///
    /// An unbalanced release (counter already at zero) trips a debug assertion. In release builds
    /// it is ignored instead of wrapping the counter to `u32::MAX`, which would otherwise leave
    /// the gate permanently saturated.
    pub fn release(&self) {
        let released = self
            .active_count
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
                current.checked_sub(1)
            });
        debug_assert!(released.is_ok(), "Released more slots than acquired");
    }

    /// Returns the number of slots currently held (a point-in-time snapshot).
    pub fn get_active_count(&self) -> u32 {
        self.active_count.load(Ordering::SeqCst)
    }

    /// Returns the configured upper bound on simultaneously held slots.
    pub fn get_max_concurrent(&self) -> u32 {
        self.max_concurrent
    }

    /// Reports whether the gate is currently full.
    ///
    /// This is a snapshot only; it reserves nothing, so the answer may be stale by the time the
    /// caller acts on it. Use [`try_acquire`](Self::try_acquire) to actually claim a slot.
    pub fn is_saturated(&self) -> bool {
        self.get_active_count() >= self.max_concurrent
    }
}
/// Caps concurrently held slots independently for each group id.
///
/// The table lives behind an `Arc`, so every clone of a gate shares (and enforces) the same
/// per-group limits.
#[derive(Debug, Clone)]
pub struct GroupLevelGate {
    /// Held-slot count per group. Entries are removed once their count drops to zero, so the
    /// table only contains groups that currently hold at least one slot.
    group_gates: Arc<Mutex<HashMap<String, u32>>>,
    /// Upper bound on simultaneously held slots for any single group.
    max_per_group: u32,
}

impl GroupLevelGate {
    /// Creates a gate that admits at most `max_per_group` holders per group at once.
    ///
    /// A limit of `0` rejects every acquisition.
    pub fn new(max_per_group: u32) -> Self {
        Self {
            group_gates: Arc::new(Mutex::new(HashMap::new())),
            max_per_group,
        }
    }

    /// Locks the group table, recovering it if another thread panicked while holding the lock.
    ///
    /// Each critical section in this impl performs a single insert/update/remove, so a poisoned
    /// lock never hides a half-applied change and the data remains safe to use.
    fn lock_gates(&self) -> std::sync::MutexGuard<'_, HashMap<String, u32>> {
        self.group_gates
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Attempts to take one slot for `group_id` without blocking.
    ///
    /// Returns `true` if a slot was taken (the caller must later call
    /// [`release_for_group`](Self::release_for_group) with the same id), or `false` if the group
    /// is already at capacity.
    pub fn try_acquire_for_group(&self, group_id: &str) -> bool {
        let mut gates = self.lock_gates();
        if let Some(active) = gates.get_mut(group_id) {
            if *active >= self.max_per_group {
                return false;
            }
            *active += 1;
            return true;
        }
        // Only create an entry when a slot is actually granted, so rejected requests leave no
        // residue in the table.
        if self.max_per_group == 0 {
            return false;
        }
        gates.insert(group_id.to_owned(), 1);
        true
    }

    /// Returns one previously acquired slot for `group_id`.
    ///
    /// When the group's count reaches zero its entry is dropped, which keeps the table from
    /// growing with every group id ever seen. Releasing a group that holds no slots is a no-op
    /// (it never underflows).
    pub fn release_for_group(&self, group_id: &str) {
        let mut gates = self.lock_gates();
        let remaining = match gates.get_mut(group_id) {
            Some(active) => {
                *active = active.saturating_sub(1);
                *active
            }
            None => return,
        };
        if remaining == 0 {
            gates.remove(group_id);
        }
    }

    /// Returns the number of slots `group_id` currently holds (`0` for unknown groups).
    pub fn get_active_count_for_group(&self, group_id: &str) -> u32 {
        self.lock_gates().get(group_id).copied().unwrap_or(0)
    }

    /// Reports whether `group_id` is currently at capacity.
    ///
    /// This is a snapshot only and reserves nothing; use
    /// [`try_acquire_for_group`](Self::try_acquire_for_group) to actually claim a slot.
    pub fn is_group_saturated(&self, group_id: &str) -> bool {
        self.get_active_count_for_group(group_id) >= self.max_per_group
    }
}
#[derive(Debug, Clone)]
pub struct CombinedThrottleGate {
instance_gate: SupervisorInstanceGate,
group_gate: Option<GroupLevelGate>,
}
impl CombinedThrottleGate {
pub fn new(instance_gate: SupervisorInstanceGate, group_gate: Option<GroupLevelGate>) -> Self {
Self {
instance_gate,
group_gate,
}
}
pub fn try_acquire(&self, group_id: Option<&str>) -> bool {
if !self.instance_gate.try_acquire() {
return false;
}
if let (Some(group_gate), Some(gid)) = (&self.group_gate, group_id)
&& !group_gate.try_acquire_for_group(gid)
{
self.instance_gate.release();
return false;
}
true
}
pub fn release(&self, group_id: Option<&str>) {
self.instance_gate.release();
if let (Some(group_gate), Some(gid)) = (&self.group_gate, group_id) {
group_gate.release_for_group(gid);
}
}
pub fn instance_gate(&self) -> &SupervisorInstanceGate {
&self.instance_gate
}
pub fn group_gate(&self) -> Option<&GroupLevelGate> {
self.group_gate.as_ref()
}
}
#[cfg(test)]
mod tests {
    // `super` always names the module defining these types, regardless of where it sits in the
    // crate, unlike a hard-coded `crate::runtime::concurrent_gate` path.
    use super::{CombinedThrottleGate, GroupLevelGate, SupervisorInstanceGate};

    #[test]
    fn test_instance_gate_basic_acquire_release() {
        let gate = SupervisorInstanceGate::new(3);
        assert_eq!(gate.get_active_count(), 0);
        assert!(gate.try_acquire());
        assert_eq!(gate.get_active_count(), 1);
        assert!(gate.try_acquire());
        assert_eq!(gate.get_active_count(), 2);
        gate.release();
        assert_eq!(gate.get_active_count(), 1);
        gate.release();
        assert_eq!(gate.get_active_count(), 0);
    }

    #[test]
    fn test_instance_gate_saturation() {
        let gate = SupervisorInstanceGate::new(2);
        assert!(gate.try_acquire());
        assert!(gate.try_acquire());
        assert!(!gate.try_acquire());
        assert!(gate.is_saturated());
    }

    #[test]
    fn test_instance_gate_clones_share_counter() {
        let gate = SupervisorInstanceGate::new(1);
        let clone = gate.clone();
        assert!(clone.try_acquire());
        assert!(!gate.try_acquire());
        assert_eq!(gate.get_active_count(), 1);
    }

    #[test]
    fn test_group_gate_isolation() {
        let gate = GroupLevelGate::new(2);
        assert!(gate.try_acquire_for_group("group-a"));
        assert!(gate.try_acquire_for_group("group-a"));
        assert!(!gate.try_acquire_for_group("group-a"));
        assert!(gate.try_acquire_for_group("group-b"));
        assert_eq!(gate.get_active_count_for_group("group-b"), 1);
        assert_eq!(gate.get_active_count_for_group("group-a"), 2);
    }

    #[test]
    fn test_combined_gate_takes_stricter_verdict() {
        let instance = SupervisorInstanceGate::new(5);
        let group = GroupLevelGate::new(2);
        let combined = CombinedThrottleGate::new(instance, Some(group));
        assert!(combined.try_acquire(Some("test-group")));
        assert!(combined.try_acquire(Some("test-group")));
        assert!(!combined.try_acquire(Some("test-group")));
    }

    #[test]
    fn test_combined_gate_rolls_back_instance_slot_on_group_rejection() {
        let combined = CombinedThrottleGate::new(
            SupervisorInstanceGate::new(5),
            Some(GroupLevelGate::new(1)),
        );
        assert!(combined.try_acquire(Some("g")));
        assert!(!combined.try_acquire(Some("g")));
        // The rejected attempt must not keep the instance slot it briefly took.
        assert_eq!(combined.instance_gate().get_active_count(), 1);

        combined.release(Some("g"));
        assert_eq!(combined.instance_gate().get_active_count(), 0);
        assert_eq!(
            combined
                .group_gate()
                .map(|g| g.get_active_count_for_group("g")),
            Some(0)
        );
    }

    #[test]
    fn test_combined_gate_without_group() {
        let instance = SupervisorInstanceGate::new(2);
        let combined = CombinedThrottleGate::new(instance, None);
        assert!(combined.try_acquire(None));
        assert!(combined.try_acquire(None));
        assert!(!combined.try_acquire(None));
    }
}