use crate::config::constants::tools;
use crate::utils::ansi_codes::{FG_GREEN, FG_MAGENTA, FG_RED, FG_YELLOW};
use serde::{Deserialize, Serialize};
/// Coarse risk classification produced by `ToolRiskScorer::calculate_risk`.
///
/// Variants are declared from least to most risky, so the derived `Ord`
/// gives `Low < Medium < High < Critical`; threshold checks such as
/// `ToolRiskScorer::requires_justification` rely on this ordering.
/// NOTE(review): the serde derives use the default variant names (no
/// `rename_all`), which differ from the lowercase names returned by `as_str`.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum RiskLevel {
/// Final score 0..=25.
Low,
/// Final score 26..=50.
Medium,
/// Final score 51..=75.
High,
/// Final score above 75.
Critical,
}
impl RiskLevel {
pub fn as_str(self) -> &'static str {
match self {
Self::Low => "low",
Self::Medium => "medium",
Self::High => "high",
Self::Critical => "critical",
}
}
pub fn color_code(self) -> &'static str {
match self {
Self::Low => FG_GREEN,
Self::Medium => FG_YELLOW,
Self::High => FG_RED,
Self::Critical => FG_MAGENTA,
}
}
}
impl std::fmt::Display for RiskLevel {
    /// Writes the lowercase level name returned by [`RiskLevel::as_str`].
    ///
    /// Uses [`std::fmt::Formatter::pad`] instead of `write_str` so that width,
    /// fill/alignment and precision flags are honored. For example,
    /// `format!("{:<8}", level)` produces an aligned column. Plain `{}` output
    /// is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
/// Where a tool comes from; each source scales the tool's risk score by its
/// own multiplier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolSource {
    /// Internal tool; score is left unscaled (x1.0).
    Internal,
    /// MCP-provided tool; score is scaled by x1.5.
    Mcp,
    /// ACP-provided tool; score is scaled by x1.2.
    Acp,
    /// Any other external tool; score is doubled (x2.0).
    External,
}

impl ToolSource {
    /// Factor the raw risk score is multiplied by for tools of this source.
    ///
    /// Listed from least to most penalizing: internal (1.0), ACP (1.2),
    /// MCP (1.5), external (2.0).
    pub fn risk_multiplier(self) -> f32 {
        match self {
            ToolSource::Internal => 1.0,
            ToolSource::Acp => 1.2,
            ToolSource::Mcp => 1.5,
            ToolSource::External => 2.0,
        }
    }
}
/// How much the current workspace is trusted.
///
/// Declared from least to most trusted, so the derived `Ord` gives
/// `Untrusted < Partial < Trusted < FullAuto`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum WorkspaceTrust {
    /// No trust; score is left unscaled (x1.0).
    Untrusted,
    /// Partial trust; score is scaled by x0.8.
    Partial,
    /// Trusted workspace; score is scaled by x0.6.
    Trusted,
    /// Full-auto workspace; score is scaled by x0.3.
    FullAuto,
}

impl WorkspaceTrust {
    /// Factor the (source-scaled) risk score is multiplied by for this trust
    /// level. Higher trust gives a smaller factor.
    pub fn risk_reduction(self) -> f32 {
        match self {
            WorkspaceTrust::FullAuto => 0.3,
            WorkspaceTrust::Trusted => 0.6,
            WorkspaceTrust::Partial => 0.8,
            WorkspaceTrust::Untrusted => 1.0,
        }
    }
}
/// Description of a single tool invocation, as consumed by
/// `ToolRiskScorer::calculate_risk`.
///
/// Create one with [`ToolRiskContext::new`], then refine it with the chained
/// builder methods [`with_args`](Self::with_args), [`as_write`](Self::as_write),
/// [`as_destructive`](Self::as_destructive) and
/// [`accesses_network`](Self::accesses_network).
#[derive(Debug, Clone)]
pub struct ToolRiskContext {
    /// Name of the tool being invoked; selects the base risk score.
    pub tool_name: String,
    /// Origin of the tool; scales the score via `ToolSource::risk_multiplier`.
    pub source: ToolSource,
    /// Trust level of the workspace; scales the score via
    /// `WorkspaceTrust::risk_reduction`.
    pub workspace_trust: WorkspaceTrust,
    /// Count of recent approvals; lowers the score (only the first three count).
    pub recent_approvals: usize,
    /// Arguments of the invocation.
    /// NOTE(review): not read by `ToolRiskScorer::calculate_risk`.
    pub command_args: Vec<String>,
    /// Whether the invocation writes; raises the score when set.
    pub is_write: bool,
    /// Whether the invocation is destructive; raises the score when set.
    pub is_destructive: bool,
    /// Whether the invocation touches the network; raises the score when set.
    /// Shares its name with the builder method that sets it.
    pub accesses_network: bool,
}

impl ToolRiskContext {
    /// Creates a context with zero recent approvals, no arguments and every
    /// operation flag (`is_write`, `is_destructive`, `accesses_network`) cleared.
    pub fn new(tool_name: String, source: ToolSource, workspace_trust: WorkspaceTrust) -> Self {
        Self {
            tool_name,
            source,
            workspace_trust,
            recent_approvals: 0,
            command_args: Vec::new(),
            is_write: false,
            is_destructive: false,
            accesses_network: false,
        }
    }

    /// Replaces the command arguments and returns the updated context.
    // `#[must_use]`: these builders consume `self`, so discarding the result
    // would silently drop the configuration (clippy `return_self_not_must_use`).
    #[must_use]
    pub fn with_args(mut self, args: Vec<String>) -> Self {
        self.command_args = args;
        self
    }

    /// Marks the invocation as a write and returns the updated context.
    #[must_use]
    pub fn as_write(mut self) -> Self {
        self.is_write = true;
        self
    }

    /// Marks the invocation as destructive and returns the updated context.
    #[must_use]
    pub fn as_destructive(mut self) -> Self {
        self.is_destructive = true;
        self
    }

    /// Marks the invocation as accessing the network and returns the updated
    /// context.
    #[must_use]
    pub fn accesses_network(mut self) -> Self {
        self.accesses_network = true;
        self
    }
}
/// Stateless scorer that turns a [`ToolRiskContext`] into a [`RiskLevel`].
pub struct ToolRiskScorer;

impl ToolRiskScorer {
    /// Computes the risk level of the invocation described by `ctx`.
    ///
    /// Scoring steps, in order:
    /// 1. start from the per-tool baseline (`base_risk_for_tool`);
    /// 2. add 30 if destructive, 15 if writing, 10 if accessing the network;
    /// 3. multiply by the source multiplier, truncating to an integer;
    /// 4. multiply by the workspace trust reduction, truncating again;
    /// 5. subtract 5 per recent approval (at most 3 count), never going below 0.
    ///
    /// The resulting score maps to `Low` (0-25), `Medium` (26-50),
    /// `High` (51-75) or `Critical` (above 75).
    pub fn calculate_risk(ctx: &ToolRiskContext) -> RiskLevel {
        let flag_penalty = u32::from(ctx.is_destructive) * 30
            + u32::from(ctx.is_write) * 15
            + u32::from(ctx.accesses_network) * 10;
        let raw = Self::base_risk_for_tool(&ctx.tool_name) + flag_penalty;

        // Each scaling step truncates separately; the intermediate cast is
        // intentional and affects the final score.
        let scaled_by_source = (raw as f32 * ctx.source.risk_multiplier()) as u32;
        let scaled_by_trust =
            (scaled_by_source as f32 * ctx.workspace_trust.risk_reduction()) as u32;

        let approval_credit = 5 * ctx.recent_approvals.min(3) as u32;
        let score = scaled_by_trust.saturating_sub(approval_credit);

        if score <= 25 {
            RiskLevel::Low
        } else if score <= 50 {
            RiskLevel::Medium
        } else if score <= 75 {
            RiskLevel::High
        } else {
            RiskLevel::Critical
        }
    }

    /// Returns `true` when `risk` is at or above `threshold`.
    pub fn requires_justification(risk: RiskLevel, threshold: RiskLevel) -> bool {
        threshold <= risk
    }

    /// Baseline score for a tool name, before any context adjustments.
    ///
    /// The listed tools get fixed scores from 0 to 40. Any other name scores
    /// 30 if it starts with `mcp_`, and 35 otherwise.
    fn base_risk_for_tool(tool_name: &str) -> u32 {
        match tool_name {
            tools::READ_FILE | tools::UNIFIED_SEARCH => 0,
            "file_info" | "status" | "logs" => 5,
            tools::WRITE_FILE | tools::EDIT_FILE | tools::CREATE_FILE => 20,
            tools::APPLY_PATCH | tools::DELETE_FILE => 25,
            tools::CREATE_PTY_SESSION
            | tools::RUN_PTY_CMD
            | tools::SEND_PTY_INPUT
            | tools::UNIFIED_EXEC => 35,
            "web_search" | "fetch_url" | "unified_search:web" => 40,
            unlisted if unlisted.starts_with("mcp_") => 30,
            _ => 35,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_risk_level_ordering() {
        assert!(RiskLevel::Low < RiskLevel::Medium);
        assert!(RiskLevel::Medium < RiskLevel::High);
        assert!(RiskLevel::High < RiskLevel::Critical);
    }

    #[test]
    fn test_risk_level_display_matches_as_str() {
        for level in [
            RiskLevel::Low,
            RiskLevel::Medium,
            RiskLevel::High,
            RiskLevel::Critical,
        ] {
            assert_eq!(level.to_string(), level.as_str());
        }
    }

    #[test]
    fn test_risk_calculation() {
        let ctx = ToolRiskContext::new(
            tools::READ_FILE.to_string(),
            ToolSource::Internal,
            WorkspaceTrust::Trusted,
        );
        let risk = ToolRiskScorer::calculate_risk(&ctx);
        assert_eq!(risk, RiskLevel::Low);
        let ctx = ToolRiskContext::new(
            tools::WRITE_FILE.to_string(),
            ToolSource::External,
            WorkspaceTrust::Untrusted,
        )
        .as_write();
        let risk = ToolRiskScorer::calculate_risk(&ctx);
        assert!(risk >= RiskLevel::High);
    }

    #[test]
    fn test_approval_history_reduces_risk() {
        // PTY command, internal, untrusted: 35 -> Medium; three approvals
        // subtract 15 -> 20 -> Low. A strict comparison ensures approvals
        // actually lower the level instead of merely not raising it.
        let mut ctx = ToolRiskContext::new(
            tools::RUN_PTY_CMD.to_string(),
            ToolSource::Internal,
            WorkspaceTrust::Untrusted,
        );
        let risk_before = ToolRiskScorer::calculate_risk(&ctx);
        ctx.recent_approvals = 3;
        let risk_after = ToolRiskScorer::calculate_risk(&ctx);
        assert!(risk_after < risk_before);
    }

    #[test]
    fn test_approval_reduction_is_capped() {
        // 35 (PTY) + 15 (write) = 50; only three approvals count (-15) -> 35,
        // so extra approvals must not push the level below Medium.
        let mut ctx = ToolRiskContext::new(
            tools::RUN_PTY_CMD.to_string(),
            ToolSource::Internal,
            WorkspaceTrust::Untrusted,
        )
        .as_write();
        ctx.recent_approvals = 3;
        let capped = ToolRiskScorer::calculate_risk(&ctx);
        ctx.recent_approvals = 10;
        assert_eq!(ToolRiskScorer::calculate_risk(&ctx), capped);
        assert_eq!(capped, RiskLevel::Medium);
    }

    #[test]
    fn test_source_multiplier() {
        let base = ToolRiskContext::new(
            "mcp_tool".to_string(),
            ToolSource::Internal,
            WorkspaceTrust::Trusted,
        );
        let base_risk = ToolRiskScorer::calculate_risk(&base);
        let mcp = ToolRiskContext::new(
            "mcp_tool".to_string(),
            ToolSource::Mcp,
            WorkspaceTrust::Trusted,
        );
        let mcp_risk = ToolRiskScorer::calculate_risk(&mcp);
        assert!(mcp_risk > base_risk || mcp_risk == RiskLevel::Critical);
    }

    #[test]
    fn test_requires_justification() {
        assert!(ToolRiskScorer::requires_justification(
            RiskLevel::High,
            RiskLevel::High
        ));
        assert!(!ToolRiskScorer::requires_justification(
            RiskLevel::Medium,
            RiskLevel::High
        ));
    }
}