use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use rmcp::ErrorData as McpError;
#[cfg(feature = "mcp")]
use tokio::sync::Semaphore;
/// Admission-control category assigned to a tool call.
///
/// [`ToolClass::ReadOnly`] is never gated by a guard; every other class is
/// mapped to its own guard by the concurrency configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolClass {
    AuditFull,
    DecompileAll,
    TaintAnalyze,
    Default,
    ReadOnly,
}

impl ToolClass {
    /// Returns the kebab-case label for this class.
    ///
    /// The label is what refusal errors report in their message and in the
    /// `class` field of their structured data.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::AuditFull => "audit-full",
            Self::DecompileAll => "decompile-all",
            Self::TaintAnalyze => "taint-analyze",
            Self::Default => "default",
            Self::ReadOnly => "read-only",
        }
    }
}
/// Reason a class guard declined to admit a call.
#[derive(Debug)]
pub enum ConcurrencyRefused {
    /// Every concurrency slot for the class was already in use.
    Concurrency {
        /// Label of the refused class.
        class: &'static str,
        /// Configured number of simultaneous calls for the class.
        max_concurrent: usize,
    },
    /// The per-minute call budget for the class was exhausted.
    RateLimit {
        /// Label of the refused class.
        class: &'static str,
        /// Configured number of calls allowed per minute for the class.
        max_per_min: u32,
    },
}

impl ConcurrencyRefused {
    /// Converts the refusal into an MCP error.
    ///
    /// Both variants use error code `-32000`. The message names the class
    /// and the limit that was hit. The `data` payload is a JSON object whose
    /// `type` is `"ConcurrencyExceeded"` or `"RateLimitExceeded"`, together
    /// with the class label and the configured limit.
    pub fn into_mcp_error(self) -> McpError {
        let (message, data) = match self {
            Self::Concurrency {
                class,
                max_concurrent,
            } => (
                format!(
                    "concurrency limit exceeded for tool class {class:?}: \
                     max {max_concurrent} concurrent call(s) allowed. \
                     Retry after the in-flight call completes."
                ),
                serde_json::json!({
                    "type": "ConcurrencyExceeded",
                    "class": class,
                    "max_concurrent": max_concurrent,
                }),
            ),
            Self::RateLimit { class, max_per_min } => (
                format!(
                    "rate limit exceeded for tool class {class:?}: \
                     max {max_per_min} call(s) per minute allowed. \
                     Retry after the rate window resets."
                ),
                serde_json::json!({
                    "type": "RateLimitExceeded",
                    "class": class,
                    "max_per_min": max_per_min,
                }),
            ),
        };
        McpError::new(rmcp::model::ErrorCode(-32000), message, Some(data))
    }
}
/// Per-minute call budget for one tool class.
///
/// This is a fixed window: once 60 seconds have passed since the last
/// refill, the next call restores the full budget and restarts the window.
struct BucketState {
    /// Calls still allowed in the current window.
    tokens: u32,
    /// Start of the current window.
    last_refill: Instant,
    /// Budget restored at every refill.
    max_tokens: u32,
}

impl BucketState {
    /// Creates a bucket with a full budget of `max_per_min` calls, starting its window now.
    fn new(max_per_min: u32) -> Self {
        let started = Instant::now();
        Self {
            max_tokens: max_per_min,
            tokens: max_per_min,
            last_refill: started,
        }
    }

    /// Spends one call from the budget, refilling first if the window has elapsed.
    ///
    /// Returns `true` if the call is admitted. Returns `false` if the budget
    /// is empty; in that case `tokens` is left unchanged.
    fn try_consume(&mut self) -> bool {
        const REFILL_WINDOW: Duration = Duration::from_secs(60);

        let current = Instant::now();
        if current.duration_since(self.last_refill) >= REFILL_WINDOW {
            self.last_refill = current;
            self.tokens = self.max_tokens;
        }

        match self.tokens.checked_sub(1) {
            Some(remaining) => {
                self.tokens = remaining;
                true
            }
            None => false,
        }
    }
}
/// Admission guard for a single [`ToolClass`].
///
/// Combines a semaphore that caps the number of in-flight calls with a
/// per-minute call budget ([`BucketState`]). A call is admitted only when both
/// limits allow it.
pub struct ClassGuard {
    /// Concurrency cap: one permit per in-flight call.
    semaphore: Arc<Semaphore>,
    /// Per-minute call budget, behind a mutex so admission decisions are serialized.
    bucket: Arc<Mutex<BucketState>>,
    /// Configured concurrency cap, reported in refusals.
    max_concurrent: usize,
    /// Configured per-minute budget, reported in refusals.
    max_per_min: u32,
    /// Class this guard protects, reported in refusals.
    class: ToolClass,
}

impl ClassGuard {
    /// Creates a guard that admits at most `max_concurrent` simultaneous calls
    /// and at most `max_per_min` calls per refill window.
    fn new(class: ToolClass, max_concurrent: usize, max_per_min: u32) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
            bucket: Arc::new(Mutex::new(BucketState::new(max_per_min))),
            max_concurrent,
            max_per_min,
            class,
        }
    }

    /// Tries to admit one call without waiting.
    ///
    /// On success, the returned permit holds a concurrency slot until it is
    /// dropped, and one unit of the per-minute budget has been spent.
    ///
    /// # Errors
    ///
    /// * [`ConcurrencyRefused::Concurrency`] if all `max_concurrent` slots are
    ///   in use. No rate budget is spent in this case.
    /// * [`ConcurrencyRefused::RateLimit`] if the per-minute budget is
    ///   exhausted. The concurrency slot is released before returning.
    pub fn try_acquire(&self) -> Result<tokio::sync::SemaphorePermit<'_>, ConcurrencyRefused> {
        // Hold the bucket lock for the whole admission decision. The semaphore
        // is private and only acquired here, so a permit that is taken and then
        // handed back on the rate-limited path is never visible to another caller.
        // A poisoned lock is recovered: the bucket holds only plain counters and
        // an `Instant`, so there is no partially-built state to protect.
        let mut bucket = self.bucket.lock().unwrap_or_else(|e| e.into_inner());

        // Check concurrency first. A permit can be returned by dropping it, but a
        // consumed rate token cannot. Checking the budget first would spend a
        // token on a call that is then refused for concurrency.
        let Ok(permit) = self.semaphore.try_acquire() else {
            return Err(ConcurrencyRefused::Concurrency {
                class: self.class.as_str(),
                max_concurrent: self.max_concurrent,
            });
        };

        if bucket.try_consume() {
            Ok(permit)
        } else {
            // Release the slot while the bucket lock is still held.
            drop(permit);
            Err(ConcurrencyRefused::RateLimit {
                class: self.class.as_str(),
                max_per_min: self.max_per_min,
            })
        }
    }
}
/// Admission control for tool calls, with one [`ClassGuard`] per gated class.
///
/// [`ToolClass::ReadOnly`] has no guard and is always admitted. Every refusal
/// returned by [`ConcurrencyConfig::acquire`] increments `refused_total`.
pub struct ConcurrencyConfig {
    pub audit_full: ClassGuard,
    pub decompile_all: ClassGuard,
    pub taint_analyze: ClassGuard,
    pub default: ClassGuard,
    pub refused_total: Arc<std::sync::atomic::AtomicU64>,
}

impl ConcurrencyConfig {
    /// Builds a configuration with one guard per gated class.
    ///
    /// Each class gets its own concurrency cap. All classes use the same
    /// `max_per_min` value, but each class has its own independent budget.
    pub fn new(
        max_concurrent_audit: usize,
        max_concurrent_decompile: usize,
        max_concurrent_taint: usize,
        max_concurrent_default: usize,
        max_per_min: u32,
    ) -> Self {
        let make_guard = |class: ToolClass, max_concurrent: usize| {
            ClassGuard::new(class, max_concurrent, max_per_min)
        };
        Self {
            audit_full: make_guard(ToolClass::AuditFull, max_concurrent_audit),
            decompile_all: make_guard(ToolClass::DecompileAll, max_concurrent_decompile),
            taint_analyze: make_guard(ToolClass::TaintAnalyze, max_concurrent_taint),
            default: make_guard(ToolClass::Default, max_concurrent_default),
            refused_total: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        }
    }

    /// Admits a call of the given class, or refuses it.
    ///
    /// Returns `Ok(None)` for [`ToolClass::ReadOnly`]. Returns `Ok(Some(permit))`
    /// when the class guard admits the call; dropping the permit frees the
    /// concurrency slot.
    ///
    /// # Errors
    ///
    /// Returns the MCP error produced by the guard's refusal. Each refusal
    /// also increments `refused_total`.
    pub fn acquire(
        &self,
        class: ToolClass,
    ) -> Result<Option<tokio::sync::SemaphorePermit<'_>>, McpError> {
        let Some(guard) = self.guard_for(class) else {
            return Ok(None);
        };
        guard.try_acquire().map(Some).map_err(|refused| {
            self.refused_total
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            refused.into_mcp_error()
        })
    }

    /// Maps a class to its guard; `None` means the class is not gated.
    fn guard_for(&self, class: ToolClass) -> Option<&ClassGuard> {
        match class {
            ToolClass::ReadOnly => None,
            ToolClass::AuditFull => Some(&self.audit_full),
            ToolClass::DecompileAll => Some(&self.decompile_all),
            ToolClass::TaintAnalyze => Some(&self.taint_analyze),
            ToolClass::Default => Some(&self.default),
        }
    }

    /// Number of calls refused so far, across all classes.
    pub fn refused_total(&self) -> u64 {
        self.refused_total.load(std::sync::atomic::Ordering::Relaxed)
    }
}
/// Summarizes the configured concurrency caps and the refusal count.
///
/// All four gated classes are listed, including `default`. Runtime guard
/// state (semaphores, rate buckets) is omitted, and that omission is marked
/// with `..` via `finish_non_exhaustive`.
impl std::fmt::Debug for ConcurrencyConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConcurrencyConfig")
            .field(
                "audit_full.max_concurrent",
                &self.audit_full.max_concurrent,
            )
            .field(
                "decompile_all.max_concurrent",
                &self.decompile_all.max_concurrent,
            )
            .field(
                "taint_analyze.max_concurrent",
                &self.taint_analyze.max_concurrent,
            )
            .field("default.max_concurrent", &self.default.max_concurrent)
            .field("refused_total", &self.refused_total())
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config where every gated class shares the same limits.
    fn test_cfg(max_concurrent: usize, max_per_min: u32) -> ConcurrencyConfig {
        ConcurrencyConfig::new(
            max_concurrent,
            max_concurrent,
            max_concurrent,
            max_concurrent,
            max_per_min,
        )
    }

    #[tokio::test]
    async fn concurrency_cap_enforced() {
        let cfg = test_cfg(1, 100);
        let p1 = cfg.acquire(ToolClass::AuditFull);
        assert!(p1.is_ok(), "first acquire should succeed");
        let refused = cfg
            .acquire(ToolClass::AuditFull)
            .err()
            .expect("second acquire should fail (max=1 concurrent)");
        assert!(
            refused.message.contains("concurrency limit exceeded"),
            "expected a concurrency refusal, got: {}",
            refused.message
        );
        drop(p1);
        let p3 = cfg.acquire(ToolClass::AuditFull);
        assert!(p3.is_ok(), "acquire after drop should succeed");
    }

    #[tokio::test]
    async fn rate_limit_enforced() {
        let cfg = test_cfg(100, 2);
        let p1 = cfg.acquire(ToolClass::AuditFull);
        assert!(p1.is_ok());
        let p2 = cfg.acquire(ToolClass::AuditFull);
        assert!(p2.is_ok());
        let refused = cfg
            .acquire(ToolClass::AuditFull)
            .err()
            .expect("third acquire should be rate-limited");
        assert!(
            refused.message.contains("rate limit exceeded"),
            "expected a rate-limit refusal, got: {}",
            refused.message
        );
    }

    #[tokio::test]
    async fn readonly_always_admitted() {
        // Zero slots and zero budget: any gated class would be refused here.
        let cfg = test_cfg(0, 0);
        let result = cfg.acquire(ToolClass::ReadOnly);
        assert!(result.is_ok());
        assert!(result.unwrap().is_none(), "ReadOnly returns None permit");
        assert_eq!(cfg.refused_total(), 0, "ReadOnly is never counted as refused");
    }

    #[tokio::test]
    async fn classes_are_isolated() {
        let cfg = test_cfg(1, 100);
        let _audit = cfg.acquire(ToolClass::AuditFull);
        assert!(
            cfg.acquire(ToolClass::DecompileAll).is_ok(),
            "a saturated class must not block other classes"
        );
        assert!(cfg.acquire(ToolClass::TaintAnalyze).is_ok());
        assert!(cfg.acquire(ToolClass::Default).is_ok());
        assert_eq!(cfg.refused_total(), 0);
    }

    #[tokio::test]
    async fn refused_total_increments() {
        let cfg = test_cfg(1, 100);
        let _p = cfg.acquire(ToolClass::AuditFull);
        let _ = cfg.acquire(ToolClass::AuditFull);
        assert_eq!(cfg.refused_total(), 1);
    }

    #[tokio::test]
    async fn refused_total_counts_rate_limit_refusals() {
        let cfg = test_cfg(100, 1);
        let _p = cfg.acquire(ToolClass::Default);
        let _ = cfg.acquire(ToolClass::Default);
        assert_eq!(cfg.refused_total(), 1);
    }
}