droidsaw 2.0.0

DROIDSAW — unified Android reverse engineering CLI. Hermes, DEX, APK signing. JSON output, MCP server. Bytecode is not a security layer.
Documentation
//! Per-class concurrency semaphore + token-bucket rate limiter for MCP tools.
//!
//! **Problem:** The MCP server has no concurrency cap. `audit(mode=full)` is 1–15 min wall
//! on real APKs (semgrep + trufflehog subprocesses). Two concurrent long audits
//! on a 2-core machine saturate all tokio workers and block the entire server.
//! This module closes that gap with:
//!
//! 1. A per-class `tokio::sync::Semaphore` that limits concurrent in-flight calls.
//! 2. A per-class token-bucket rate limiter that limits calls per minute, closing
//!    the "burst saturate then quickly release" attack shape that bypasses a
//!    semaphore-only gate.
//!
//! **Design:**
//! - Classes are enumerated at compile time. New long-running tools get a class
//!   assignment; cheap read-only tools are `ReadOnly` (uncapped).
//! - Concurrency caps are set at `DroidsawServer::new()` time from CLI args;
//!   defaults are conservative (audit-full=1, decompile-all=1, taint=2).
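//! - Rate limits use a fixed-window token bucket: N tokens, restored all at
//!   once when at least 60 seconds have elapsed since the previous refill.
//!   Because the refill is a single step rather than continuous, up to 2N calls
//!   can be admitted in a short span straddling a refill boundary. Bucket
//!   updates are serialized by a mutex, so concurrent callers cannot both take
//!   the last token.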
//!
//! **Typed errors:**
//! `ConcurrencyRefused::Concurrency` carries the discriminable class name + configured cap.
//! `ConcurrencyRefused::RateLimit` carries the class name + configured per-minute rate.
//! Both convert to `rmcp::ErrorData` (MCP wire error) via
//! `ConcurrencyRefused::into_mcp_error`.

use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

use rmcp::ErrorData as McpError;

#[cfg(feature = "mcp")]
use tokio::sync::Semaphore;

/// Classification of MCP tools for concurrency and rate-limit gating.
///
/// Every newly added long-running tool must be mapped onto one of the capped
/// variants. `ReadOnly` covers the inexpensive inspection tools (info,
/// manifest, strings, xrefs, query, triage, …); those bypass the gates
/// entirely, because capping them would only hinder legitimate use.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolClass {
    /// `audit(mode=full|semgrep|trufflehog)`: launches semgrep and trufflehog
    /// as subprocesses and takes 1–15 min of wall time on real APKs.
    AuditFull,
    /// The `decompile(all=true)` workload on large DEX inputs; CPU-heavy.
    DecompileAll,
    /// `taint` on large inputs; lighter than a full audit but still notable.
    TaintAnalyze,
    /// Any tool not covered by the classes above; the default cap applies.
    Default,
    /// Inexpensive read-only tools (info, manifest, strings, xrefs, query, …).
    /// Never subject to the concurrency cap.
    ReadOnly,
}

impl ToolClass {
    /// Stable kebab-case label used in error messages and counters.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::AuditFull => "audit-full",
            Self::DecompileAll => "decompile-all",
            Self::TaintAnalyze => "taint-analyze",
            Self::Default => "default",
            Self::ReadOnly => "read-only",
        }
    }
}

/// Error returned when a tool call is refused by the concurrency gate.
///
/// Each variant carries the class label (as produced by `ToolClass::as_str`)
/// and the configured limit that was hit. Callers can therefore report or
/// match on the refusal without parsing a message string.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConcurrencyRefused {
    /// Too many concurrent calls to this tool class.
    Concurrency {
        class: &'static str,
        max_concurrent: usize,
    },
    /// Rate limit exceeded for this tool class.
    RateLimit {
        class: &'static str,
        max_per_min: u32,
    },
}

/// Human-readable refusal text. It matches the message carried by the MCP
/// wire error, so logs and client-visible errors read the same.
impl std::fmt::Display for ConcurrencyRefused {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Concurrency {
                class,
                max_concurrent,
            } => write!(
                f,
                "concurrency limit exceeded for tool class {class:?}: \
                 max {max_concurrent} concurrent call(s) allowed. \
                 Retry after the in-flight call completes."
            ),
            Self::RateLimit { class, max_per_min } => write!(
                f,
                "rate limit exceeded for tool class {class:?}: \
                 max {max_per_min} call(s) per minute allowed. \
                 Retry after the rate window resets."
            ),
        }
    }
}

/// Lets the refusal participate in standard error plumbing
/// (`Box<dyn Error>`, error-reporting crates, `source()` chains).
impl std::error::Error for ConcurrencyRefused {}

impl ConcurrencyRefused {
    /// JSON-RPC error code used for both refusal kinds. `-32000` lies in the
    /// `-32000..=-32099` band that JSON-RPC 2.0 reserves for
    /// implementation-defined server errors.
    const ERROR_CODE: i32 = -32000;

    /// Convert to an MCP wire error.
    ///
    /// The structured `data` payload carries a `type` discriminator
    /// (`ConcurrencyExceeded` / `RateLimitExceeded`), the class name, and the
    /// limit that was hit. Callers can branch on it without parsing the
    /// human-readable message.
    pub fn into_mcp_error(self) -> McpError {
        match self {
            ConcurrencyRefused::Concurrency {
                class,
                max_concurrent,
            } => McpError::new(
                rmcp::model::ErrorCode(Self::ERROR_CODE),
                format!(
                    "concurrency limit exceeded for tool class {class:?}: \
                     max {max_concurrent} concurrent call(s) allowed. \
                     Retry after the in-flight call completes."
                ),
                Some(serde_json::json!({
                    "type": "ConcurrencyExceeded",
                    "class": class,
                    "max_concurrent": max_concurrent,
                })),
            ),
            ConcurrencyRefused::RateLimit { class, max_per_min } => McpError::new(
                rmcp::model::ErrorCode(Self::ERROR_CODE),
                format!(
                    "rate limit exceeded for tool class {class:?}: \
                     max {max_per_min} call(s) per minute allowed. \
                     Retry after the rate window resets."
                ),
                Some(serde_json::json!({
                    "type": "RateLimitExceeded",
                    "class": class,
                    "max_per_min": max_per_min,
                })),
            ),
        }
    }
}

/// Allows MCP handlers to propagate a refusal with `?` (or `.into()`).
/// Delegates to [`ConcurrencyRefused::into_mcp_error`].
impl From<ConcurrencyRefused> for McpError {
    fn from(refused: ConcurrencyRefused) -> Self {
        refused.into_mcp_error()
    }
}

/// Rate-limit state for one tool class (stored behind a mutex by its owner).
///
/// Works as a fixed-window bucket: the whole allowance is restored in a single
/// step once a full window has elapsed since the previous refill.
struct BucketState {
    /// Tokens still available in the current window.
    tokens: u32,
    /// Instant at which the bucket was last topped up to `max_tokens`.
    last_refill: Instant,
    /// Allowance per window (= max calls per minute).
    max_tokens: u32,
}

impl BucketState {
    /// Length of one refill window.
    const REFILL_WINDOW: Duration = Duration::from_secs(60);

    /// Create a bucket that starts full with `max_per_min` tokens.
    fn new(max_per_min: u32) -> Self {
        Self {
            max_tokens: max_per_min,
            tokens: max_per_min,
            last_refill: Instant::now(),
        }
    }

    /// Take one token; returns `false` when none are left.
    ///
    /// If at least `REFILL_WINDOW` has passed since the last refill, the
    /// bucket is first restored to `max_tokens` and the window restarts now.
    fn try_consume(&mut self) -> bool {
        let now = Instant::now();
        let window_expired = now.duration_since(self.last_refill) >= Self::REFILL_WINDOW;
        if window_expired {
            self.last_refill = now;
            self.tokens = self.max_tokens;
        }
        match self.tokens.checked_sub(1) {
            Some(remaining) => {
                self.tokens = remaining;
                true
            }
            None => false,
        }
    }
}

/// Concurrency + rate-limit gate for one capped `ToolClass`.
///
/// Pairs a `Semaphore` (in-flight cap) with a mutex-guarded `BucketState`
/// (per-minute cap). `ConcurrencyConfig` holds one of these per capped class;
/// `ReadOnly` has no guard and is always admitted.
pub struct ClassGuard {
    semaphore: Arc<Semaphore>,
    bucket: Arc<Mutex<BucketState>>,
    max_concurrent: usize,
    max_per_min: u32,
    class: ToolClass,
}

impl ClassGuard {
    /// Build a guard for `class`. It admits at most `max_concurrent` in-flight
    /// calls and at most `max_per_min` calls per 60-second window.
    ///
    /// `max_concurrent` is clamped to `Semaphore::MAX_PERMITS`, because
    /// `Semaphore::new` panics above that bound. Per the module docs the value
    /// comes from CLI args, so a huge value must not crash startup. The clamped
    /// value is what gets reported in refusals.
    fn new(class: ToolClass, max_concurrent: usize, max_per_min: u32) -> Self {
        let max_concurrent = max_concurrent.min(Semaphore::MAX_PERMITS);
        Self {
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
            bucket: Arc::new(Mutex::new(BucketState::new(max_per_min))),
            max_concurrent,
            max_per_min,
            class,
        }
    }

    /// Attempt to admit one call by taking a concurrency permit and then a
    /// rate-limit token.
    ///
    /// The semaphore is checked first so that a call refused for concurrency
    /// does not spend a rate-limit token. Otherwise, a client that retries
    /// against a busy class would drain the per-minute budget. The
    /// `Concurrency` error message explicitly tells clients to retry after the
    /// in-flight call, so such retries are expected. That client would then be
    /// refused with `RateLimit` once the in-flight call completes. If the
    /// bucket turns out to be empty, the permit taken here is dropped before
    /// returning, so the slot is released immediately.
    ///
    /// On success, hold the returned `SemaphorePermit` for the duration of the
    /// tool call; dropping it frees the slot.
    ///
    /// # Errors
    /// - `ConcurrencyRefused::Concurrency` if no semaphore permit is free. No
    ///   rate token is consumed in this case.
    /// - `ConcurrencyRefused::RateLimit` if a permit was free but the bucket
    ///   has no tokens left in the current window.
    pub fn try_acquire(&self) -> Result<tokio::sync::SemaphorePermit<'_>, ConcurrencyRefused> {
        // Concurrency check: non-blocking try_acquire.
        // `TryAcquireError` is `Closed` or `NoPermits`; both map to
        // "concurrency cap reached" from the caller's perspective and
        // carry no actionable detail beyond what we already report
        // (class + max_concurrent). Discarding it is intentional.
        #[allow(
            clippy::map_err_ignore,
            reason = "TryAcquireError carries no caller-actionable detail"
        )]
        let permit = self.semaphore.try_acquire().map_err(|_| {
            ConcurrencyRefused::Concurrency {
                class: self.class.as_str(),
                max_concurrent: self.max_concurrent,
            }
        })?;

        // Rate-limit check, only for calls that were granted a slot. A poisoned
        // lock is recovered rather than propagated: the bucket holds only
        // counters and a timestamp. The temporary guard is released at the end
        // of this statement.
        let has_token = self
            .bucket
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .try_consume();
        if !has_token {
            // `permit` is dropped here, returning the slot to the semaphore.
            return Err(ConcurrencyRefused::RateLimit {
                class: self.class.as_str(),
                max_per_min: self.max_per_min,
            });
        }

        Ok(permit)
    }
}

/// Concurrency + rate-limit guards for every capped tool class.
///
/// Created from CLI args and stored in `DroidsawServer`. `ReadOnly` tools have
/// no guard here and are always admitted by `ConcurrencyConfig::acquire`.
pub struct ConcurrencyConfig {
    /// Guard for `ToolClass::AuditFull`.
    pub audit_full: ClassGuard,
    /// Guard for `ToolClass::DecompileAll`.
    pub decompile_all: ClassGuard,
    /// Guard for `ToolClass::TaintAnalyze`.
    pub taint_analyze: ClassGuard,
    /// Guard for `ToolClass::Default`.
    pub default: ClassGuard,
    /// Monotonic counter of calls refused by `ConcurrencyConfig::acquire`
    /// across all classes. It counts both concurrency-cap refusals and
    /// rate-limit refusals.
    /// Exposed in audit envelopes as `mcp_concurrency_refused_total`.
    pub refused_total: Arc<std::sync::atomic::AtomicU64>,
}

impl ConcurrencyConfig {
    /// Create guards with the given per-class caps and a shared rate
    /// of `max_per_min` calls/minute per class.
    ///
    /// Suggested defaults (enforced by `DroidsawServer::new`):
    /// - `audit-full`: max 1 concurrent, 8/min.
    /// - `decompile-all`: max 1 concurrent, 8/min.
    /// - `taint-analyze`: max 2 concurrent, 8/min.
    /// - `default`: max 2 concurrent, 8/min.
    pub fn new(
        max_concurrent_audit: usize,
        max_concurrent_decompile: usize,
        max_concurrent_taint: usize,
        max_concurrent_default: usize,
        max_per_min: u32,
    ) -> Self {
        Self {
            audit_full: ClassGuard::new(ToolClass::AuditFull, max_concurrent_audit, max_per_min),
            decompile_all: ClassGuard::new(
                ToolClass::DecompileAll,
                max_concurrent_decompile,
                max_per_min,
            ),
            taint_analyze: ClassGuard::new(
                ToolClass::TaintAnalyze,
                max_concurrent_taint,
                max_per_min,
            ),
            default: ClassGuard::new(ToolClass::Default, max_concurrent_default, max_per_min),
            refused_total: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        }
    }

    /// Acquire a guard for the given class, or return a typed `McpError`.
    ///
    /// `ReadOnly` always succeeds (returns `None`).
    /// Other classes try the semaphore + rate limit; on failure, increments
    /// `refused_total` and returns `McpError`.
    pub fn acquire(
        &self,
        class: ToolClass,
    ) -> Result<Option<tokio::sync::SemaphorePermit<'_>>, McpError> {
        let guard = match class {
            ToolClass::ReadOnly => return Ok(None),
            ToolClass::AuditFull => &self.audit_full,
            ToolClass::DecompileAll => &self.decompile_all,
            ToolClass::TaintAnalyze => &self.taint_analyze,
            ToolClass::Default => &self.default,
        };
        match guard.try_acquire() {
            Ok(permit) => Ok(Some(permit)),
            Err(refused) => {
                self.refused_total
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                Err(refused.into_mcp_error())
            }
        }
    }

    /// Current value of the refused counter. Surfaces in audit envelopes.
    pub fn refused_total(&self) -> u64 {
        self.refused_total.load(std::sync::atomic::Ordering::Relaxed)
    }
}

/// Summarizes the configured per-class caps and the refusal counter.
///
/// Per-guard runtime state (semaphore permits, bucket tokens) is deliberately
/// omitted; `finish_non_exhaustive` marks that with a trailing `..`.
impl std::fmt::Debug for ConcurrencyConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConcurrencyConfig")
            .field(
                "audit_full.max_concurrent",
                &self.audit_full.max_concurrent,
            )
            .field(
                "decompile_all.max_concurrent",
                &self.decompile_all.max_concurrent,
            )
            .field(
                "taint_analyze.max_concurrent",
                &self.taint_analyze.max_concurrent,
            )
            // Previously missing: the `default` class is capped too.
            .field("default.max_concurrent", &self.default.max_concurrent)
            .field("refused_total", &self.refused_total())
            .finish_non_exhaustive()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Config in which every capped class shares one concurrency cap and one rate.
    fn uniform_cfg(cap: usize, per_min: u32) -> ConcurrencyConfig {
        ConcurrencyConfig::new(cap, cap, cap, cap, per_min)
    }

    #[tokio::test]
    async fn concurrency_cap_enforced() {
        // Cap of one: whoever holds the only slot blocks a second caller until
        // the permit is released.
        let cfg = uniform_cfg(1, 100);

        let held = cfg.acquire(ToolClass::AuditFull);
        assert!(held.is_ok(), "first acquire should succeed");

        let blocked = cfg.acquire(ToolClass::AuditFull);
        assert!(blocked.is_err(), "second acquire should fail (max=1 concurrent)");

        drop(held);
        let after_release = cfg.acquire(ToolClass::AuditFull);
        assert!(after_release.is_ok(), "acquire after drop should succeed");
    }

    #[tokio::test]
    async fn rate_limit_enforced() {
        // Two calls per minute under a generous cap: both admitted calls keep
        // their permits, and the third is refused by the bucket.
        let cfg = uniform_cfg(100, 2);
        let admitted: Vec<_> = (0..2).map(|_| cfg.acquire(ToolClass::AuditFull)).collect();
        for outcome in &admitted {
            assert!(outcome.is_ok());
        }
        let third = cfg.acquire(ToolClass::AuditFull);
        assert!(third.is_err(), "third acquire should be rate-limited");
    }

    #[tokio::test]
    async fn readonly_always_admitted() {
        // A zero cap and zero rate would refuse every capped class; ReadOnly
        // skips both gates.
        let cfg = uniform_cfg(0, 0);
        let outcome = cfg.acquire(ToolClass::ReadOnly);
        assert!(outcome.is_ok());
        assert!(matches!(outcome, Ok(None)), "ReadOnly returns None permit");
    }

    #[tokio::test]
    async fn refused_total_increments() {
        let cfg = uniform_cfg(1, 100);
        // Occupy the single slot so the following attempt is refused.
        let _slot = cfg.acquire(ToolClass::AuditFull);
        let _ = cfg.acquire(ToolClass::AuditFull);
        assert_eq!(cfg.refused_total(), 1);
    }
}