use std::collections::HashSet;
use std::path::Path;
use std::sync::{Arc, Mutex};
use agent_client_protocol_schema::{PermissionOptionId, PermissionOptionKind};
use serde::{Deserialize, Serialize};
use crate::tool::SafetyClass;
/// Result of [`SandboxPolicy::classify`] for a single tool invocation.
///
/// Serialized as an internally tagged object. The variant name is written to a
/// `"kind"` field in snake_case (`"allow"`, `"deny"`, `"ask"`). For `Ask`, the
/// fields of [`Ask`] appear alongside that tag.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum PolicyDecision {
    /// The invocation may run without further confirmation.
    Allow,
    /// The invocation must not run.
    Deny,
    /// The invocation needs an explicit choice from the user among `Ask::options`.
    Ask(Ask),
}
/// Payload of [`PolicyDecision::Ask`]: the choices offered for one invocation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Ask {
    /// Options to present to the user.
    pub options: Vec<AskOption>,
}
/// One selectable answer inside an [`Ask`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct AskOption {
    /// Identifier of this option; it comes back as `option_id` in
    /// [`RecordedOutcome::Selected`] when chosen (presumably — confirm with the caller).
    pub id: PermissionOptionId,
    /// Human-readable label.
    pub name: String,
    /// Protocol-level option kind (this file uses `AllowOnce`, `AllowAlways`, `RejectOnce`).
    pub kind: PermissionOptionKind,
    /// Whether choosing this option permits the invocation.
    pub allows: bool,
}
/// How an [`Ask`] was resolved, as handed to [`SandboxPolicy::record`].
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RecordedOutcome {
    /// One of the offered options was picked.
    Selected {
        /// Id of the picked option.
        option_id: PermissionOptionId,
        /// Whether the picked option permits the invocation
        /// (presumably mirrors `AskOption::allows` — confirm with the caller).
        allows: bool,
    },
    /// No option was picked.
    Cancelled,
}
/// The information a [`SandboxPolicy`] receives about one tool invocation.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate must go
/// through [`PolicyCtx::new`] instead of a struct literal.
#[non_exhaustive]
pub struct PolicyCtx<'a> {
    /// Name of the tool being invoked.
    pub tool_name: &'a str,
    /// Safety classification supplied with the tool; each policy decides how to use it.
    pub safety_hint: SafetyClass,
    /// The tool's arguments as JSON.
    pub args: &'a serde_json::Value,
    /// Working directory associated with the invocation (presumably the session cwd — confirm).
    pub cwd: &'a Path,
}

impl<'a> PolicyCtx<'a> {
    /// Bundles the pieces describing a tool invocation into a context value.
    pub fn new(
        tool_name: &'a str,
        safety_hint: SafetyClass,
        args: &'a serde_json::Value,
        cwd: &'a Path,
    ) -> Self {
        PolicyCtx {
            cwd,
            args,
            safety_hint,
            tool_name,
        }
    }
}
/// A pluggable rule set deciding whether tool invocations may run.
///
/// Implementors must be `Send + Sync`. Both methods take `&self`, so any policy
/// that keeps state must use interior mutability.
pub trait SandboxPolicy: Send + Sync {
    /// Returns the decision for the invocation described by `ctx`.
    fn classify(&self, ctx: PolicyCtx<'_>) -> PolicyDecision;
    /// Reports how a prompt for `ctx` was resolved, so the policy can learn from it.
    ///
    /// NOTE(review): presumably called only after `classify` returned
    /// [`PolicyDecision::Ask`] — confirm with the caller.
    fn record(&self, ctx: PolicyCtx<'_>, outcome: RecordedOutcome);
}
/// Permits every invocation unconditionally and keeps no state.
pub struct OpenPolicy;

impl SandboxPolicy for OpenPolicy {
    fn classify(&self, _invocation: PolicyCtx<'_>) -> PolicyDecision {
        PolicyDecision::Allow
    }

    fn record(&self, _invocation: PolicyCtx<'_>, _result: RecordedOutcome) {
        // Stateless: there is nothing to remember.
    }
}
/// Permits only invocations hinted as [`SafetyClass::ReadOnly`] and denies the rest.
/// It never asks, and it keeps no state.
pub struct ReadOnlyPolicy;

impl SandboxPolicy for ReadOnlyPolicy {
    fn classify(&self, ctx: PolicyCtx<'_>) -> PolicyDecision {
        if matches!(ctx.safety_hint, SafetyClass::ReadOnly) {
            PolicyDecision::Allow
        } else {
            PolicyDecision::Deny
        }
    }

    fn record(&self, _invocation: PolicyCtx<'_>, _result: RecordedOutcome) {
        // Never asks, so there is nothing to learn.
    }
}
/// Refuses every invocation unconditionally and keeps no state.
pub struct DenyAllPolicy;

impl SandboxPolicy for DenyAllPolicy {
    fn classify(&self, _invocation: PolicyCtx<'_>) -> PolicyDecision {
        PolicyDecision::Deny
    }

    fn record(&self, _invocation: PolicyCtx<'_>, _result: RecordedOutcome) {
        // Stateless: there is nothing to remember.
    }
}
/// Lets read-only tools through, asks about everything else, and remembers
/// (in memory, per instance) the tools the user chose to allow always.
pub struct AskWritesPolicy {
    /// Names of tools the user has approved with the "allow always" option.
    always_allow: Mutex<HashSet<String>>,
}

impl AskWritesPolicy {
    /// Creates a policy whose "always allow" table starts empty.
    pub fn new() -> Self {
        let always_allow = Mutex::new(HashSet::new());
        Self { always_allow }
    }
}

impl Default for AskWritesPolicy {
    /// Equivalent to [`AskWritesPolicy::new`].
    fn default() -> Self {
        AskWritesPolicy::new()
    }
}
impl SandboxPolicy for AskWritesPolicy {
fn classify(&self, ctx: PolicyCtx<'_>) -> PolicyDecision {
if matches!(ctx.safety_hint, SafetyClass::ReadOnly) {
return PolicyDecision::Allow;
}
if let Ok(table) = self.always_allow.lock()
&& table.contains(ctx.tool_name)
{
return PolicyDecision::Allow;
}
PolicyDecision::Ask(default_ask_options(ctx.tool_name))
}
fn record(&self, ctx: PolicyCtx<'_>, outcome: RecordedOutcome) {
let RecordedOutcome::Selected { option_id, allows } = outcome else {
return;
};
if !allows {
return;
}
if option_id.0.as_ref() != ALLOW_ALWAYS_ID {
return;
}
if let Ok(mut table) = self.always_allow.lock() {
table.insert(ctx.tool_name.to_string());
}
}
}
/// Adapts another policy for contexts where nobody can answer a prompt.
///
/// Any [`PolicyDecision::Ask`] from the inner policy becomes
/// [`PolicyDecision::Deny`]. `Allow` and `Deny` pass through unchanged.
pub struct NonInteractivePolicy {
    inner: Arc<dyn SandboxPolicy>,
}

impl NonInteractivePolicy {
    /// Wraps `inner` so that its prompts are turned into denials.
    pub fn new(inner: Arc<dyn SandboxPolicy>) -> Self {
        NonInteractivePolicy { inner }
    }
}

impl SandboxPolicy for NonInteractivePolicy {
    fn classify(&self, ctx: PolicyCtx<'_>) -> PolicyDecision {
        let decision = self.inner.classify(ctx);
        if matches!(decision, PolicyDecision::Ask(_)) {
            PolicyDecision::Deny
        } else {
            decision
        }
    }

    /// Outcomes are not forwarded to the inner policy.
    ///
    /// NOTE(review): presumably intentional, since this wrapper never returns
    /// `Ask` — confirm before changing.
    fn record(&self, _invocation: PolicyCtx<'_>, _result: RecordedOutcome) {}
}
/// A selectable sandbox mode: a user-facing identity paired with the policy it enforces.
#[derive(Clone)]
pub struct PolicyMode {
    /// Identifier used to select this mode.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Optional longer description.
    pub description: Option<String>,
    /// Policy applied while this mode is active.
    pub policy: Arc<dyn SandboxPolicy>,
}

// `dyn SandboxPolicy` does not require `Debug`, so the policy is left out and
// the output is marked non-exhaustive.
impl std::fmt::Debug for PolicyMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            id,
            name,
            description,
            policy: _,
        } = self;
        let mut out = f.debug_struct("PolicyMode");
        out.field("id", id);
        out.field("name", name);
        out.field("description", description);
        out.finish_non_exhaustive()
    }
}
/// An ordered list of [`PolicyMode`]s, one of which is currently selected.
///
/// Invariant: `current` always equals the `id` of some entry in `modes`.
/// [`ModeCatalog::new`] and [`ModeCatalog::set_current`] both reject unknown
/// ids, which is what lets [`ModeCatalog::current_policy`] rely on its lookup.
#[derive(Debug, Clone)]
pub struct ModeCatalog {
    modes: Vec<PolicyMode>,
    current: String,
}

impl ModeCatalog {
    /// Builds a catalog with `current` as the initially selected mode.
    ///
    /// Returns `None` if no mode in `modes` has that id, including when
    /// `modes` is empty.
    #[must_use]
    pub fn new(modes: Vec<PolicyMode>, current: impl Into<String>) -> Option<Self> {
        let current: String = current.into();
        if modes.iter().any(|mode| mode.id == current) {
            Some(Self { modes, current })
        } else {
            None
        }
    }

    /// Id of the currently selected mode.
    #[must_use]
    pub fn current_id(&self) -> &str {
        self.current.as_str()
    }

    /// Policy of the currently selected mode. If several modes share the id,
    /// the first one wins.
    ///
    /// # Panics
    ///
    /// Only if the type's invariant has been broken, i.e. `current` names no mode.
    #[must_use]
    pub fn current_policy(&self) -> Arc<dyn SandboxPolicy> {
        let active = self
            .modes
            .iter()
            .find(|mode| mode.id == self.current)
            .expect("ModeCatalog current id must always resolve to a mode");
        Arc::clone(&active.policy)
    }

    /// All modes, in the order they were supplied.
    #[must_use]
    pub fn modes(&self) -> &[PolicyMode] {
        self.modes.as_slice()
    }

    /// Selects the mode with id `id`.
    ///
    /// Returns `true` on success. Returns `false`, leaving the selection
    /// unchanged, when no such mode exists.
    pub fn set_current(&mut self, id: &str) -> bool {
        let exists = self.modes.iter().any(|mode| mode.id == id);
        if exists {
            self.current = id.to_owned();
        }
        exists
    }
}
/// Option id for approving a single invocation.
const ALLOW_ONCE_ID: &str = "allow_once";
/// Option id for approving a tool permanently. `AskWritesPolicy::record` remembers
/// the tool in its in-memory table when this id is selected.
const ALLOW_ALWAYS_ID: &str = "allow_always";
/// Option id for rejecting a single invocation.
const REJECT_ONCE_ID: &str = "reject_once";
fn default_ask_options(tool_name: &str) -> Ask {
let options = vec![
AskOption {
id: PermissionOptionId::new(ALLOW_ONCE_ID),
name: format!("Allow `{tool_name}` once"),
kind: PermissionOptionKind::AllowOnce,
allows: true,
},
AskOption {
id: PermissionOptionId::new(ALLOW_ALWAYS_ID),
name: format!("Allow `{tool_name}` always"),
kind: PermissionOptionKind::AllowAlways,
allows: true,
},
AskOption {
id: PermissionOptionId::new(REJECT_ONCE_ID),
name: "Reject".to_string(),
kind: PermissionOptionKind::RejectOnce,
allows: false,
},
];
Ask { options }
}
#[cfg(test)]
mod tests;