use std::sync::Arc;
use tokio::sync::{Mutex, Semaphore};
/// Gatekeeper that hands out [`AgentPermit`]s subject to a global slot limit
/// and per-mode quotas.
///
/// Cloning is cheap: the semaphore and the running counters live behind
/// `Arc`s, so every clone observes and updates the same shared state.
#[derive(Clone)]
pub struct ConcurrencyController {
/// Global slot pool; created with `global_max` permits in `new`.
global: Arc<Semaphore>,
/// Upper bounds on simultaneously running agents, per mode.
quotas: ModeQuotas,
/// Number of agents currently holding a permit, per mode. Shared with every
/// issued `AgentPermit` so the permit can decrement it when dropped.
running: Arc<Mutex<RunningCounts>>,
/// Configured fairness policy. In this file only
/// `FairnessPolicy::Proportional` alters how `acquire` behaves.
fairness: FairnessPolicy,
}
/// Per-mode limits on how many agents may hold a permit at the same time.
///
/// These limits are checked in addition to the controller's global slot limit.
#[derive(Debug, Clone, Copy)]
pub struct ModeQuotas {
/// Maximum number of concurrently running time-driven agents.
pub time_max: usize,
/// Maximum number of concurrently running issue-driven agents.
pub issue_max: usize,
}
/// Bookkeeping of how many permits are currently outstanding for each mode.
///
/// Both counters start at zero (`Default`); they are incremented when a permit
/// is issued and decremented when an `AgentPermit` is dropped.
#[derive(Debug, Default)]
struct RunningCounts {
/// Outstanding permits issued for time-driven agents.
time_driven: usize,
/// Outstanding permits issued for issue-driven agents.
issue_driven: usize,
}
/// Policy selector for how slots are shared between agent modes.
///
/// Values can be parsed from text via [`std::str::FromStr`]; matching is
/// case-insensitive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FairnessPolicy {
    /// Parsed from `round_robin`, `round-robin` or `roundrobin`.
    RoundRobin,
    /// Parsed from `priority`.
    Priority,
    /// Parsed from `proportional`.
    Proportional,
}

impl std::str::FromStr for FairnessPolicy {
    type Err = String;

    /// Parses a policy name, ignoring letter case.
    ///
    /// # Errors
    ///
    /// Returns `unknown fairness policy: <input>` (with the input exactly as
    /// given) when the name is not recognised. Surrounding whitespace is not
    /// trimmed.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let policy = match normalized.as_str() {
            "roundrobin" | "round-robin" | "round_robin" => Self::RoundRobin,
            "priority" => Self::Priority,
            "proportional" => Self::Proportional,
            // Report the caller's original spelling, not the lowercased copy.
            _ => return Err(format!("unknown fairness policy: {}", s)),
        };
        Ok(policy)
    }
}
/// Guard representing one running agent's claim on a concurrency slot.
///
/// Obtained from `ConcurrencyController::acquire_time_driven` or
/// `acquire_issue_driven`. Keep it alive for as long as the agent runs;
/// dropping it gives the global slot back and decrements the running count
/// for its mode (see the `Drop` impl).
pub struct AgentPermit {
/// Owned permit from the controller's global semaphore; held only so that
/// it is released when this guard is dropped.
_global: tokio::sync::OwnedSemaphorePermit,
/// Mode this permit was issued for; selects which counter to decrement.
mode: AgentMode,
/// Shared per-mode counters of the issuing controller.
running: Arc<Mutex<RunningCounts>>,
}
/// Internal tag for the kind of agent a permit belongs to; it selects which
/// quota and which running counter apply.
#[derive(Debug, Clone, Copy)]
enum AgentMode {
/// Counted against `ModeQuotas::time_max`.
TimeDriven,
/// Counted against `ModeQuotas::issue_max`.
IssueDriven,
}
impl Drop for AgentPermit {
fn drop(&mut self) {
let mode = self.mode;
let running = self.running.clone();
tokio::spawn(async move {
let mut counts = running.lock().await;
match mode {
AgentMode::TimeDriven => counts.time_driven -= 1,
AgentMode::IssueDriven => counts.issue_driven -= 1,
}
});
}
}
impl ConcurrencyController {
pub fn new(global_max: usize, quotas: ModeQuotas, fairness: FairnessPolicy) -> Self {
Self {
global: Arc::new(Semaphore::new(global_max)),
quotas,
running: Arc::new(Mutex::new(RunningCounts::default())),
fairness,
}
}
pub async fn acquire_time_driven(&self) -> Option<AgentPermit> {
self.acquire(AgentMode::TimeDriven).await
}
pub async fn acquire_issue_driven(&self) -> Option<AgentPermit> {
self.acquire(AgentMode::IssueDriven).await
}
pub async fn running_counts(&self) -> (usize, usize) {
let counts = self.running.lock().await;
(counts.time_driven, counts.issue_driven)
}
pub fn available_slots(&self) -> usize {
self.global.available_permits()
}
async fn mode_has_capacity(&self, mode: AgentMode) -> bool {
let counts = self.running.lock().await;
match mode {
AgentMode::TimeDriven => counts.time_driven < self.quotas.time_max,
AgentMode::IssueDriven => counts.issue_driven < self.quotas.issue_max,
}
}
pub fn fairness_policy(&self) -> FairnessPolicy {
self.fairness
}
async fn acquire(&self, mode: AgentMode) -> Option<AgentPermit> {
if !self.mode_has_capacity(mode).await {
tracing::debug!(?mode, "mode quota exceeded");
return None;
}
if self.fairness == FairnessPolicy::Proportional {
let counts = self.running.lock().await;
let total = counts.time_driven + counts.issue_driven;
let global_cap = self.global.available_permits() + total;
if global_cap > 0 {
let mode_count = match mode {
AgentMode::TimeDriven => counts.time_driven,
AgentMode::IssueDriven => counts.issue_driven,
};
let mode_quota = match mode {
AgentMode::TimeDriven => self.quotas.time_max,
AgentMode::IssueDriven => self.quotas.issue_max,
};
let total_quota = self.quotas.time_max + self.quotas.issue_max;
let fair_share = (global_cap * mode_quota) / total_quota.max(1);
if mode_count >= fair_share && fair_share > 0 {
tracing::debug!(?mode, mode_count, fair_share, "proportional fairness limit");
return None;
}
}
}
let global_permit = match self.global.clone().try_acquire_owned() {
Ok(p) => p,
Err(_) => {
tracing::debug!("global concurrency limit reached");
return None;
}
};
{
let mut counts = self.running.lock().await;
match mode {
AgentMode::TimeDriven => counts.time_driven += 1,
AgentMode::IssueDriven => counts.issue_driven += 1,
}
}
tracing::debug!(?mode, "acquired concurrency slot");
Some(AgentPermit {
_global: global_permit,
mode,
running: self.running.clone(),
})
}
}
impl Default for ModeQuotas {
fn default() -> Self {
Self {
time_max: 3,
issue_max: 2,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The global limit caps acquisitions, and dropping a permit frees a slot.
    #[tokio::test]
    async fn test_acquire_release() {
        let controller = ConcurrencyController::new(
            2,
            ModeQuotas {
                time_max: 2,
                issue_max: 2,
            },
            FairnessPolicy::RoundRobin,
        );
        let permit1 = controller.acquire_time_driven().await;
        assert!(permit1.is_some());
        let permit2 = controller.acquire_time_driven().await;
        assert!(permit2.is_some());
        let permit3 = controller.acquire_time_driven().await;
        assert!(permit3.is_none());
        drop(permit1);
        // The permit's destructor may hand the counter decrement to a
        // background task; give that task a chance to run before retrying.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        let permit4 = controller.acquire_time_driven().await;
        assert!(permit4.is_some());
    }

    /// Each mode is capped by its own quota even when global slots remain.
    #[tokio::test]
    async fn test_mode_quotas() {
        let controller = ConcurrencyController::new(
            10,
            ModeQuotas {
                time_max: 1,
                issue_max: 1,
            },
            FairnessPolicy::RoundRobin,
        );
        let time_permit = controller.acquire_time_driven().await;
        assert!(time_permit.is_some());
        let time_permit2 = controller.acquire_time_driven().await;
        assert!(time_permit2.is_none());
        let issue_permit = controller.acquire_issue_driven().await;
        assert!(issue_permit.is_some());
        let issue_permit2 = controller.acquire_issue_driven().await;
        assert!(issue_permit2.is_none());
    }

    /// Running counts reflect the permits currently held, per mode.
    #[tokio::test]
    async fn test_running_counts() {
        let controller = ConcurrencyController::new(
            5,
            ModeQuotas {
                time_max: 3,
                issue_max: 3,
            },
            FairnessPolicy::RoundRobin,
        );
        let _time_permit = controller.acquire_time_driven().await.unwrap();
        let _issue_permit = controller.acquire_issue_driven().await.unwrap();
        let (time_count, issue_count) = controller.running_counts().await;
        assert_eq!(time_count, 1);
        assert_eq!(issue_count, 1);
    }

    /// Under `Proportional`, a mode is held to its share of the global
    /// capacity (4 slots * 4 / (4 + 4) = 2) even though its quota is higher.
    #[tokio::test]
    async fn test_proportional_fair_share() {
        let controller = ConcurrencyController::new(
            4,
            ModeQuotas {
                time_max: 4,
                issue_max: 4,
            },
            FairnessPolicy::Proportional,
        );
        let _time1 = controller
            .acquire_time_driven()
            .await
            .expect("first time-driven slot is within the fair share");
        let _time2 = controller
            .acquire_time_driven()
            .await
            .expect("second time-driven slot is within the fair share");
        assert!(controller.acquire_time_driven().await.is_none());

        let _issue1 = controller
            .acquire_issue_driven()
            .await
            .expect("issue-driven mode still has its share available");
        assert_eq!(controller.running_counts().await, (2, 1));
        assert_eq!(controller.available_slots(), 1);
    }

    /// The simple accessors report the configured policy and free slots.
    #[tokio::test]
    async fn test_accessors() {
        let controller =
            ConcurrencyController::new(2, ModeQuotas::default(), FairnessPolicy::Priority);
        assert_eq!(controller.fairness_policy(), FairnessPolicy::Priority);
        assert_eq!(controller.available_slots(), 2);
        let _permit = controller.acquire_issue_driven().await.unwrap();
        assert_eq!(controller.available_slots(), 1);
    }

    #[test]
    fn test_fairness_policy_from_str() {
        assert_eq!(
            "round_robin".parse::<FairnessPolicy>().unwrap(),
            FairnessPolicy::RoundRobin
        );
        assert_eq!(
            "priority".parse::<FairnessPolicy>().unwrap(),
            FairnessPolicy::Priority
        );
        assert_eq!(
            "proportional".parse::<FairnessPolicy>().unwrap(),
            FairnessPolicy::Proportional
        );
        assert!("unknown".parse::<FairnessPolicy>().is_err());
    }

    /// Aliases and mixed case are accepted; errors echo the original input.
    #[test]
    fn test_fairness_policy_aliases_and_errors() {
        assert_eq!(
            "Round-Robin".parse::<FairnessPolicy>().unwrap(),
            FairnessPolicy::RoundRobin
        );
        assert_eq!(
            "ROUNDROBIN".parse::<FairnessPolicy>().unwrap(),
            FairnessPolicy::RoundRobin
        );
        assert_eq!(
            "Bogus".parse::<FairnessPolicy>().unwrap_err(),
            "unknown fairness policy: Bogus"
        );
    }
}