use anyhow::anyhow;
#[cfg(test)]
use serde_json::Map;
use serde_json::Value;
use std::sync::Arc;
#[cfg(test)]
use std::sync::Mutex;
use thiserror::Error;
#[cfg(test)]
use vtcode_core::config::constants::tools;
use vtcode_core::tools::{
RiskLevel, SafetyContext, SafetyDecision, SafetyError as GatewaySafetyError, SafetyGateway,
SafetyGatewayConfig, ToolInvocationId, WorkspaceTrust,
};
/// Errors produced when the safety validator rejects a tool call.
///
/// The first three variants mirror the gateway violations of the same name
/// (see `map_gateway_violation`); anything else is carried in `Other`.
#[derive(Debug, Error)]
pub(crate) enum SafetyError {
/// The per-turn tool call limit `max` has been reached.
#[error("Per-turn tool limit reached (max: {max}). Wait or adjust config.")]
TurnLimitReached { max: usize },
/// The per-session tool call limit `max` has been reached.
#[error("Session tool limit reached (max: {max}). End turn or reduce tool calls.")]
SessionLimitReached { max: usize },
/// More calls were made within `window` than the rate limit `max` allows.
#[error("Rate limit exceeded: {current} calls/{window} (max: {max})")]
RateLimitExceeded {
/// Number of calls reported by the gateway for the window.
current: usize,
/// Maximum number of calls allowed in the window.
max: usize,
/// Label of the window, interpolated into the message after "calls/".
window: &'static str,
},
/// Any other denial; displays exactly as the wrapped `anyhow::Error`.
#[error(transparent)]
Other(#[from] anyhow::Error),
}
/// Validates tool calls by delegating every check to a `SafetyGateway`.
pub(crate) struct ToolCallSafetyValidator {
/// Gateway that checks and records each call; held behind an `Arc` so it
/// can be supplied by the caller via `with_gateway`.
safety_gateway: Arc<SafetyGateway>,
/// Context handed to the gateway on every check.
gateway_ctx: SafetyContext,
/// Test-only copy of the rate limits, consulted when a test overrides the
/// per-second limit so the per-minute value can be passed along unchanged.
#[cfg(test)]
test_rate_limits: Mutex<TestRateLimits>,
}
/// Test-only pair of rate limits that is forwarded to the gateway's
/// `set_rate_limits` when a test changes the per-second limit.
#[cfg(test)]
struct TestRateLimits {
/// Per-second call limit.
per_second: usize,
/// Per-minute call limit.
/// NOTE(review): `None` presumably means "no per-minute limit" — confirm
/// against `SafetyGatewayConfig::rate_limit_per_minute`.
per_minute: Option<usize>,
}
impl ToolCallSafetyValidator {
    /// Label passed to `SafetyContext::new` by both constructors.
    const CONTEXT_LABEL: &'static str = "runloop-safety-validator";

    /// Creates a validator that owns a freshly configured gateway.
    ///
    /// The gateway is configured with a per-turn limit of 10, a per-session
    /// limit of 100, plan mode off, a trusted workspace, a `Medium` approval
    /// risk threshold and rate-limit enforcement disabled. All remaining
    /// settings come from `SafetyGatewayConfig::default()`.
    pub(crate) fn new() -> Self {
        let config = SafetyGatewayConfig {
            max_per_turn: 10,
            max_per_session: 100,
            plan_mode_active: false,
            workspace_trust: WorkspaceTrust::Trusted,
            approval_risk_threshold: RiskLevel::Medium,
            enforce_rate_limits: false,
            ..SafetyGatewayConfig::default()
        };

        // Read the configured rate limits before `config` is moved into the
        // gateway; the test-only setters start from these values.
        #[cfg(test)]
        let initial_limits = Mutex::new(TestRateLimits {
            per_second: config.rate_limit_per_second,
            per_minute: config.rate_limit_per_minute,
        });

        let gateway = SafetyGateway::with_config(config);
        Self {
            safety_gateway: Arc::new(gateway),
            gateway_ctx: SafetyContext::new(Self::CONTEXT_LABEL),
            #[cfg(test)]
            test_rate_limits: initial_limits,
        }
    }

    /// Creates a validator around an existing, possibly shared, gateway.
    ///
    /// In test builds the remembered rate limits are seeded from
    /// `SafetyGatewayConfig::default()`, not from the supplied gateway.
    pub(crate) fn with_gateway(safety_gateway: Arc<SafetyGateway>) -> Self {
        #[cfg(test)]
        let default_limits = {
            let defaults = SafetyGatewayConfig::default();
            Mutex::new(TestRateLimits {
                per_second: defaults.rate_limit_per_second,
                per_minute: defaults.rate_limit_per_minute,
            })
        };

        Self {
            safety_gateway,
            gateway_ctx: SafetyContext::new(Self::CONTEXT_LABEL),
            #[cfg(test)]
            test_rate_limits: default_limits,
        }
    }

    /// Tells the gateway that a new turn is starting.
    pub(crate) fn start_turn(&self) {
        self.safety_gateway.start_turn();
    }

    /// Forwards new per-turn and per-session limits to the gateway.
    pub(crate) fn set_limits(&self, max_per_turn: usize, max_per_session: usize) {
        let gateway = &self.safety_gateway;
        gateway.set_limits(max_per_turn, max_per_session);
    }

    /// Asks the gateway to raise its session limit by `increment`.
    pub(crate) fn increase_session_limit(&self, increment: usize) {
        self.safety_gateway.increase_session_limit(increment);
    }

    /// Test helper: overrides the per-second rate limit.
    ///
    /// A `limit` of zero is ignored. Otherwise the new value is stored and
    /// sent to the gateway together with the remembered per-minute limit.
    #[cfg(test)]
    pub fn set_rate_limit_per_second(&self, limit: usize) {
        if limit == 0 {
            return;
        }

        // A poisoned lock is recovered rather than propagated: the guarded
        // data is two plain numbers, so there is no invariant to protect.
        let mut limits = match self.test_rate_limits.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        limits.per_second = limit;

        let (per_second, per_minute) = (limits.per_second, limits.per_minute);
        self.safety_gateway.set_rate_limits(per_second, per_minute);
    }

    /// Test helper: switches the gateway's rate-limit enforcement on or off.
    #[cfg(test)]
    pub fn set_rate_limit_enforcement(&self, enabled: bool) {
        self.safety_gateway.set_rate_limit_enforcement(enabled);
    }

    /// Validates a tool call under a newly created `ToolInvocationId`.
    ///
    /// See [`Self::validate_call_with_invocation_id`] for the outcome rules.
    pub(crate) async fn validate_call(
        &self,
        tool_name: &str,
        args: &Value,
    ) -> std::result::Result<(), SafetyError> {
        let fresh_id = ToolInvocationId::new();
        self.validate_call_with_invocation_id(tool_name, args, fresh_id)
            .await
    }

    /// Validates a tool call under the given invocation id.
    ///
    /// The gateway is asked to check and record the call. `Allow` and
    /// `NeedsApproval` decisions both yield `Ok(())`; this method does not
    /// act on the approval requirement itself.
    ///
    /// # Errors
    ///
    /// Returns the [`SafetyError`] mapped from the gateway's violation (or
    /// from its textual reason when no violation is attached) on `Deny`.
    pub(crate) async fn validate_call_with_invocation_id(
        &self,
        tool_name: &str,
        args: &Value,
        invocation_id: ToolInvocationId,
    ) -> std::result::Result<(), SafetyError> {
        let outcome = self
            .safety_gateway
            .check_and_record_with_id(&self.gateway_ctx, tool_name, args, Some(invocation_id))
            .await;

        match outcome.decision {
            SafetyDecision::Deny(reason) => {
                Err(map_gateway_violation(outcome.violation, &reason))
            }
            SafetyDecision::Allow | SafetyDecision::NeedsApproval(_) => Ok(()),
        }
    }

    /// Test helper: reports whether `tool_name` is classified as destructive.
    ///
    /// The name is trimmed and ASCII-lowercased, then classified with an
    /// empty JSON object as its arguments.
    #[cfg(test)]
    pub fn is_destructive(&self, tool_name: &str) -> bool {
        let normalized = tool_name.trim().to_ascii_lowercase();
        let empty_args = Value::Object(Map::new());
        let intent = vtcode_core::tools::tool_intent::classify_tool_intent(
            normalized.as_str(),
            &empty_args,
        );
        intent.destructive
    }
}
/// Translates a gateway denial into this module's [`SafetyError`].
///
/// Turn-limit, session-limit and rate-limit violations map to the variant of
/// the same name with their fields copied over. Any other violation is
/// wrapped in `SafetyError::Other` using its display text, and a denial with
/// no attached violation is wrapped using `reason` instead.
fn map_gateway_violation(violation: Option<GatewaySafetyError>, reason: &str) -> SafetyError {
    // Without a structured violation, the textual reason is all there is.
    let violation = match violation {
        Some(violation) => violation,
        None => return SafetyError::Other(anyhow!(reason.to_string())),
    };

    match violation {
        GatewaySafetyError::TurnLimitReached { max } => SafetyError::TurnLimitReached { max },
        GatewaySafetyError::SessionLimitReached { max } => {
            SafetyError::SessionLimitReached { max }
        }
        GatewaySafetyError::RateLimitExceeded {
            current,
            max,
            window,
        } => SafetyError::RateLimitExceeded {
            current,
            max,
            window,
        },
        other => SafetyError::Other(anyhow!(other.to_string())),
    }
}
impl Default for ToolCallSafetyValidator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Validates one call to `tool` with an empty JSON object as arguments.
    async fn call(
        validator: &ToolCallSafetyValidator,
        tool: &str,
    ) -> std::result::Result<(), SafetyError> {
        validator.validate_call(tool, &json!({})).await
    }

    /// Mutating tools are classified as destructive; read-only ones are not.
    #[test]
    fn test_destructive_tool_detection() {
        let validator = ToolCallSafetyValidator::new();

        for destructive in ["delete_file", "edit_file"] {
            assert!(validator.is_destructive(destructive));
        }

        assert!(!validator.is_destructive("read_file"));
        assert!(!validator.is_destructive(tools::GREP_FILE));
    }

    /// With enforcement on and a limit of two per second, the third
    /// back-to-back call is rejected as a rate-limit violation.
    #[tokio::test]
    async fn test_rate_limiting() {
        let validator = ToolCallSafetyValidator::new();
        validator.set_rate_limit_per_second(2);
        validator.set_rate_limit_enforcement(true);
        validator.start_turn();

        for _ in 0..2 {
            assert!(call(&validator, "read_file").await.is_ok());
        }

        let third = call(&validator, "read_file").await;
        assert!(matches!(third, Err(SafetyError::RateLimitExceeded { .. })));
    }

    /// Validation itself passes for both a read-only and a destructive tool.
    #[tokio::test]
    async fn test_validation_allows_safe_and_destructive_tools() {
        let validator = ToolCallSafetyValidator::new();
        validator.start_turn();

        for tool in ["read_file", "delete_file"] {
            assert!(call(&validator, tool).await.is_ok());
        }
    }

    /// The per-turn limit resets on `start_turn`; the session limit does not.
    #[tokio::test]
    async fn test_turn_and_session_limits() {
        let validator = ToolCallSafetyValidator::new();
        validator.set_limits(2, 3);

        // First turn: two calls fit, the third exceeds the per-turn limit.
        validator.start_turn();
        for _ in 0..2 {
            assert!(call(&validator, "read_file").await.is_ok());
        }
        assert!(call(&validator, "read_file").await.is_err());

        // Second turn: one more call is accepted, then the next is rejected
        // even though this turn has used only one of its two slots.
        validator.start_turn();
        assert!(call(&validator, "read_file").await.is_ok());
        assert!(call(&validator, "read_file").await.is_err());
    }
}