use crate::mcp::config::McpConfig;
use crate::mcp::error::{McpError, McpResult};
use std::collections::VecDeque;
use std::io::{IsTerminal, Write};
use std::sync::Mutex;
use std::time::{Duration, Instant};
/// Sliding-window rate limiter for MCP tool checks and tool installs.
///
/// Each operation kind keeps a queue of the instants at which a request was
/// accepted. Entries older than one minute are evicted lazily whenever a limit
/// is checked, and a new request is rejected once the queue already holds the
/// configured maximum. The queues sit behind `Mutex`es so the limiter can be
/// used through a shared reference.
pub struct RateLimiter {
    /// Acceptance times of recent tool checks, oldest at the front.
    check_times: Mutex<VecDeque<Instant>>,
    /// Acceptance times of recent installs, oldest at the front.
    install_times: Mutex<VecDeque<Instant>>,
    /// Maximum number of tool checks accepted per rolling minute.
    max_checks_per_minute: u32,
    /// Maximum number of installs accepted per rolling minute.
    max_installs_per_minute: u32,
}

impl RateLimiter {
    /// Length of the rolling window that the per-minute limits apply to.
    const WINDOW: Duration = Duration::from_secs(60);

    /// Creates a limiter with empty history, reading both per-minute limits
    /// from `config.mcp`.
    pub fn new(config: &McpConfig) -> Self {
        Self {
            check_times: Mutex::new(VecDeque::new()),
            install_times: Mutex::new(VecDeque::new()),
            max_checks_per_minute: config.mcp.max_checks_per_minute,
            max_installs_per_minute: config.mcp.max_installs_per_minute,
        }
    }

    /// Records a tool check if it fits within `max_checks_per_minute`.
    ///
    /// # Errors
    ///
    /// Returns `McpError::rate_limited` when the limit was already reached in
    /// the current window, or `McpError::internal_error` if the lock is poisoned.
    pub fn check_check_limit(&self) -> McpResult<()> {
        self.check_limit(
            &self.check_times,
            self.max_checks_per_minute,
            "Tool check rate limit exceeded. Please wait before checking more tools.",
        )
    }

    /// Records an install if it fits within `max_installs_per_minute`.
    ///
    /// # Errors
    ///
    /// Returns `McpError::rate_limited` when the limit was already reached in
    /// the current window, or `McpError::internal_error` if the lock is poisoned.
    pub fn check_install_limit(&self) -> McpResult<()> {
        self.check_limit(
            &self.install_times,
            self.max_installs_per_minute,
            "Install rate limit exceeded. Please wait before installing more tools.",
        )
    }

    /// Shared sliding-window logic.
    ///
    /// Evicts entries older than [`Self::WINDOW`] from `times`. It then rejects
    /// with `error_message` if `max_per_minute` entries remain, and otherwise
    /// records the current instant. Rejected requests are not recorded.
    fn check_limit(
        &self,
        times: &Mutex<VecDeque<Instant>>,
        max_per_minute: u32,
        error_message: &str,
    ) -> McpResult<()> {
        let mut times = times
            .lock()
            .map_err(|_| McpError::internal_error("Lock poisoned"))?;
        let now = Instant::now();
        // `duration_since` saturates to zero. In contrast, `now - WINDOW` panics
        // when the platform clock's origin (e.g. system boot) is less than a
        // minute in the past.
        while times
            .front()
            .is_some_and(|&t| now.duration_since(t) > Self::WINDOW)
        {
            times.pop_front();
        }
        if times.len() >= max_per_minute as usize {
            return Err(McpError::rate_limited(error_message));
        }
        times.push_back(now);
        Ok(())
    }

    /// Number of tool checks recorded within the last minute.
    #[allow(dead_code)] // Not called anywhere in this file; silences the unused warning.
    pub fn check_count(&self) -> usize {
        self.get_count(&self.check_times)
    }

    /// Number of installs recorded within the last minute.
    #[allow(dead_code)] // Not called anywhere in this file; silences the unused warning.
    pub fn install_count(&self) -> usize {
        self.get_count(&self.install_times)
    }

    /// Counts entries in `times` that are still inside the window, without
    /// evicting anything.
    #[allow(dead_code)] // Only reached through the unused accessors above.
    fn get_count(&self, times: &Mutex<VecDeque<Instant>>) -> usize {
        // Best-effort read: a poisoned lock reports 0 rather than an error.
        let Ok(times) = times.lock() else {
            return 0;
        };
        let now = Instant::now();
        times
            .iter()
            .filter(|&&t| now.duration_since(t) <= Self::WINDOW)
            .count()
    }
}
/// Verifies that `tool` may be used under `config`.
///
/// # Errors
///
/// When `config.is_allowed(tool)` is false, this returns the error from
/// `McpError::tool_denied` if `config.is_denied(tool)` also holds. Otherwise it
/// returns the error from `McpError::tool_not_allowed`. `is_denied` is consulted
/// only for tools that are not allowed.
pub fn check_allowlist(tool: &str, config: &McpConfig) -> McpResult<()> {
    if config.is_allowed(tool) {
        return Ok(());
    }
    Err(if config.is_denied(tool) {
        McpError::tool_denied(tool)
    } else {
        McpError::tool_not_allowed(tool)
    })
}
/// Asks the user on the terminal whether `tool_name` may be installed.
///
/// The prompt is written to stderr and shows `tool_name`, the `command` that
/// will be executed, and the requesting `client_name` if any. A single answer
/// line is then read from stdin. The answer is trimmed and compared
/// case-insensitively:
/// - `y` / `yes` → [`ConfirmationResult::Yes`]
/// - `a` / `always` → [`ConfirmationResult::Always`]
/// - `n`, `no`, an empty line, or anything unrecognised → [`ConfirmationResult::No`]
///   (an unrecognised answer also prints a notice to stderr)
///
/// # Errors
///
/// Returns `McpError::user_cancelled()` without prompting when either stdin or
/// stderr is not a terminal. Propagates (via `?`) I/O errors from writing the
/// prompt or reading the answer.
pub fn prompt_user_confirmation(
    tool_name: &str,
    command: &str,
    client_name: Option<&str>,
) -> McpResult<ConfirmationResult> {
    // The question goes to stderr and the answer comes from stdin, so both must
    // be interactive. If stdin is a pipe (e.g. a stdio transport), reading a
    // line here would consume input that was never meant as an answer.
    if !std::io::stderr().is_terminal() || !std::io::stdin().is_terminal() {
        return Err(McpError::user_cancelled());
    }

    // The displayed strings may originate from the requesting client. Escape
    // control characters so terminal escape sequences cannot rewrite the prompt.
    let tool = sanitize_for_terminal(tool_name);
    let command = sanitize_for_terminal(command);
    let client = client_name.map(sanitize_for_terminal);

    {
        // Hold the stderr lock only while printing, so other writers cannot
        // interleave with the prompt. Release it before blocking on stdin.
        let mut stderr = std::io::stderr().lock();
        writeln!(stderr)?;
        writeln!(
            stderr,
            "┌────────────────────────────────────────────────────"
        )?;
        writeln!(stderr, "│ Jarvy MCP: Install {}?", tool)?;
        writeln!(stderr, "│")?;
        writeln!(stderr, "│ This will execute:")?;
        writeln!(stderr, "│ {}", command)?;
        writeln!(stderr, "│")?;
        if let Some(client) = &client {
            writeln!(stderr, "│ Requested by: {}", client)?;
            writeln!(stderr, "│")?;
        }
        writeln!(stderr, "│ [Y]es / [N]o / [A]lways allow {}:", tool)?;
        writeln!(
            stderr,
            "└────────────────────────────────────────────────────"
        )?;
        write!(stderr, "> ")?;
        stderr.flush()?;
    }

    let mut response = String::new();
    std::io::stdin().read_line(&mut response)?;
    match response.trim().to_lowercase().as_str() {
        "y" | "yes" => Ok(ConfirmationResult::Yes),
        "n" | "no" | "" => Ok(ConfirmationResult::No),
        "a" | "always" => Ok(ConfirmationResult::Always),
        _ => {
            writeln!(std::io::stderr(), "Invalid response. Interpreting as 'no'.")?;
            Ok(ConfirmationResult::No)
        }
    }
}

/// Replaces every control character in `text` with its Rust escape form
/// (e.g. ESC becomes `\u{1b}`, newline becomes `\n`), so the text prints on one
/// line and cannot inject terminal escape sequences.
fn sanitize_for_terminal(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    for c in text.chars() {
        if c.is_control() {
            out.extend(c.escape_default());
        } else {
            out.push(c);
        }
    }
    out
}
/// The user's answer to the interactive install prompt in
/// [`prompt_user_confirmation`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConfirmationResult {
    /// The user answered `y` / `yes`.
    Yes,
    /// The user answered `n` / `no`, gave an empty line, or gave an
    /// unrecognised answer.
    No,
    /// The user answered `a` / `always` ("Always allow" in the prompt). Any
    /// persistence of that preference is up to the caller.
    Always,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every tool check up to the default per-minute ceiling is accepted.
    #[test]
    fn test_rate_limiter_allows_within_limit() {
        let config = McpConfig::default();
        let limiter = RateLimiter::new(&config);
        let ceiling = config.mcp.max_checks_per_minute;
        (0..ceiling).for_each(|_| assert!(limiter.check_check_limit().is_ok()));
    }

    /// With a ceiling of 2, the third check in the same window is rejected.
    #[test]
    fn test_rate_limiter_blocks_over_limit() {
        let mut config = McpConfig::default();
        config.mcp.max_checks_per_minute = 2;
        let limiter = RateLimiter::new(&config);
        for _ in 0..2 {
            assert!(limiter.check_check_limit().is_ok());
        }
        assert!(limiter.check_check_limit().is_err());
    }

    /// With an install ceiling of 1, the second install is rejected.
    #[test]
    fn test_rate_limiter_install_limit() {
        let mut config = McpConfig::default();
        config.mcp.max_installs_per_minute = 1;
        let limiter = RateLimiter::new(&config);
        assert!(limiter.check_install_limit().is_ok());
        assert!(limiter.check_install_limit().is_err());
    }

    /// The default config has neither list configured, so every tool passes.
    #[test]
    fn test_check_allowlist_no_lists() {
        let config = McpConfig::default();
        for tool in ["git", "anything"] {
            assert!(check_allowlist(tool, &config).is_ok());
        }
    }

    /// A denylisted tool is rejected with code -32002; other tools still pass.
    #[test]
    fn test_check_allowlist_with_denylist() {
        let mut config = McpConfig::default();
        config.mcp.denylist = Some(vec!["brew".to_string()]);
        assert!(check_allowlist("git", &config).is_ok());
        let err = check_allowlist("brew", &config).expect_err("denylisted tool must be rejected");
        assert_eq!(err.code, -32002);
    }

    /// With an allowlist, listed tools pass and an unlisted one is rejected with
    /// code -32003.
    #[test]
    fn test_check_allowlist_with_allowlist() {
        let mut config = McpConfig::default();
        config.mcp.allowlist = Some(vec!["git".to_string(), "docker".to_string()]);
        for tool in ["git", "docker"] {
            assert!(check_allowlist(tool, &config).is_ok());
        }
        let err = check_allowlist("vim", &config).expect_err("unlisted tool must be rejected");
        assert_eq!(err.code, -32003);
    }
}