use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::util::BoxFuture;
/// Outcome of a tool-permission check.
///
/// Serialized as an internally tagged JSON object: the `"decision"` key holds the
/// snake_case variant name (`"allow"` or `"deny"`) and the variant's fields sit beside it.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case", tag = "decision")]
pub enum PermissionDecision {
    /// The tool call is permitted.
    Allow {
        /// Optional updated tool input. Omitted from the serialized form when `None`,
        /// and read back as `None` when the key is missing.
        #[serde(default)]
        #[serde(skip_serializing_if = "Option::is_none")]
        updated_input: Option<serde_json::Value>,
    },
    /// The tool call is refused.
    Deny {
        /// Human-readable reason for the refusal.
        message: String,
        /// Interrupt flag; deserializes to `false` when the key is missing.
        #[serde(default)]
        interrupt: bool,
    },
}
impl PermissionDecision {
    /// Builds an [`Allow`](Self::Allow) decision with no `updated_input`.
    #[must_use]
    pub fn allow() -> Self {
        Self::Allow { updated_input: None }
    }

    /// Builds an [`Allow`](Self::Allow) decision carrying `input` as `updated_input`.
    #[must_use]
    pub fn allow_with_input(input: serde_json::Value) -> Self {
        let updated_input = Some(input);
        Self::Allow { updated_input }
    }

    /// Builds a [`Deny`](Self::Deny) decision with the given message and `interrupt` unset.
    #[must_use]
    pub fn deny(message: impl Into<String>) -> Self {
        Self::denial(message.into(), false)
    }

    /// Builds a [`Deny`](Self::Deny) decision with the given message and `interrupt` set.
    #[must_use]
    pub fn deny_and_interrupt(message: impl Into<String>) -> Self {
        Self::denial(message.into(), true)
    }

    /// Shared constructor behind [`deny`](Self::deny) and
    /// [`deny_and_interrupt`](Self::deny_and_interrupt).
    fn denial(message: String, interrupt: bool) -> Self {
        Self::Deny { message, interrupt }
    }
}
/// Context handed by value to a `CanUseToolCallback` as its third argument.
#[derive(Clone, Debug)]
pub struct PermissionContext {
    /// Tool-use identifier.
    pub tool_use_id: String,
    /// Session identifier.
    pub session_id: String,
    /// Control-request identifier.
    pub request_id: String,
    /// Suggestions attached to the permission request. Presumably copied from
    /// `ControlRequestData::PermissionRequest::suggestions`; confirm at the construction site.
    pub suggestions: Vec<String>,
}
/// Shared callback that decides whether a tool may be used.
///
/// It is called with a `&str` (presumably the tool name; confirm at call sites), the tool
/// input as JSON, and a [`PermissionContext`]. It returns a `BoxFuture` (defined in
/// `crate::util`) that yields a [`PermissionDecision`]. The closure must be `Send + Sync`,
/// so one `Arc` can be shared across threads.
pub type CanUseToolCallback = Arc<
    dyn Fn(&str, &serde_json::Value, PermissionContext) -> BoxFuture<PermissionDecision>
        + Sync
        + Send,
>;
/// Control-request envelope: an identifier plus a [`ControlRequestData`] payload.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub(crate) struct ControlRequest {
    /// Request identifier; presumably echoed back in `ControlResponse::request_id`.
    pub request_id: String,
    /// The typed request payload, serialized as a nested object under `"request"`.
    pub request: ControlRequestData,
}
/// Payload of a [`ControlRequest`], internally tagged by a `"type"` key that holds the
/// snake_case variant name.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub(crate) enum ControlRequestData {
    /// Permission check for a tool use; serialized with `"type": "permission_request"`.
    PermissionRequest {
        /// Name of the tool whose use is being checked.
        tool_name: String,
        /// Input the tool would be invoked with.
        tool_input: serde_json::Value,
        /// Tool-use identifier.
        tool_use_id: String,
        /// Suggestions for the responder; an empty list when the key is missing.
        #[serde(default)]
        suggestions: Vec<String>,
    },
}
/// Control-response envelope; presumably answers the [`ControlRequest`] whose
/// `request_id` matches.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub(crate) struct ControlResponse {
    /// Free-form response kind string (the tests use `"permission_response"`).
    pub kind: String,
    /// Identifier of the request being answered.
    pub request_id: String,
    /// The permission outcome, serialized as a nested object under `"result"`.
    pub result: ControlResponseResult,
}
/// Wire-level permission result, internally tagged by a `"type"` key (`"allow"` / `"deny"`).
///
/// Mirrors [`PermissionDecision`] field-for-field (see its `From` conversion), but uses
/// `"type"` rather than `"decision"` as the tag key.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case", tag = "type")]
pub(crate) enum ControlResponseResult {
    /// The tool call is permitted.
    Allow {
        /// Optional updated tool input. Omitted when `None`, and `None` when missing.
        #[serde(default)]
        #[serde(skip_serializing_if = "Option::is_none")]
        updated_input: Option<serde_json::Value>,
    },
    /// The tool call is refused.
    Deny {
        /// Human-readable reason for the refusal.
        message: String,
        /// Interrupt flag; deserializes to `false` when the key is missing.
        #[serde(default)]
        interrupt: bool,
    },
}
impl From<PermissionDecision> for ControlResponseResult {
fn from(decision: PermissionDecision) -> Self {
match decision {
PermissionDecision::Allow { updated_input } => Self::Allow { updated_input },
PermissionDecision::Deny { message, interrupt } => Self::Deny { message, interrupt },
}
}
}
// NOTE(review): `dead_code` is allowed because, presumably, nothing outside the unit tests
// constructs or reads this legacy type; confirm before removing the allow or the type.
#[allow(dead_code)]
/// Flat permission-request shape used by the legacy protocol.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
pub(crate) struct LegacyPermissionRequest {
    /// Request identifier.
    pub request_id: String,
    /// Name of the tool involved.
    pub tool_name: String,
    /// Free-text description of the action.
    pub action: String,
    /// Optional extra detail; `None` when the key is missing.
    #[serde(default)]
    pub details: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn permission_decision_allow_round_trip() {
        let d = PermissionDecision::allow();
        let json = serde_json::to_string(&d).unwrap();
        let decoded: PermissionDecision = serde_json::from_str(&json).unwrap();
        assert_eq!(d, decoded);
    }

    #[test]
    fn permission_decision_allow_with_input() {
        let d = PermissionDecision::allow_with_input(json!({"modified": true}));
        let json = serde_json::to_string(&d).unwrap();
        assert!(json.contains("modified"));
        let decoded: PermissionDecision = serde_json::from_str(&json).unwrap();
        assert_eq!(d, decoded);
    }

    #[test]
    fn permission_decision_deny_round_trip() {
        let d = PermissionDecision::deny("not allowed");
        let json = serde_json::to_string(&d).unwrap();
        let decoded: PermissionDecision = serde_json::from_str(&json).unwrap();
        assert_eq!(d, decoded);
    }

    #[test]
    fn permission_decision_deny_interrupt() {
        // Compare whole values so both the message and the flag are checked.
        assert_eq!(
            PermissionDecision::deny_and_interrupt("abort"),
            PermissionDecision::Deny {
                message: "abort".into(),
                interrupt: true,
            }
        );
        // The plain `deny` builder must leave the interrupt flag off.
        assert_eq!(
            PermissionDecision::deny("soft"),
            PermissionDecision::Deny {
                message: "soft".into(),
                interrupt: false,
            }
        );
    }

    /// Pins the JSON shape: `"decision"` tag, snake_case names, and `updated_input`
    /// omitted when `None`. Round-trip tests alone would not catch a renamed tag.
    #[test]
    fn permission_decision_wire_format() {
        assert_eq!(
            serde_json::to_value(&PermissionDecision::allow()).unwrap(),
            json!({"decision": "allow"})
        );
        assert_eq!(
            serde_json::to_value(&PermissionDecision::allow_with_input(json!({"modified": true})))
                .unwrap(),
            json!({"decision": "allow", "updated_input": {"modified": true}})
        );
        assert_eq!(
            serde_json::to_value(&PermissionDecision::deny("no")).unwrap(),
            json!({"decision": "deny", "message": "no", "interrupt": false})
        );
    }

    #[test]
    fn permission_decision_deny_interrupt_defaults_to_false() {
        let decoded: PermissionDecision =
            serde_json::from_value(json!({"decision": "deny", "message": "x"})).unwrap();
        assert_eq!(decoded, PermissionDecision::deny("x"));
    }

    #[test]
    fn control_request_permission_round_trip() {
        let req = ControlRequest {
            request_id: "req-1".into(),
            request: ControlRequestData::PermissionRequest {
                tool_name: "bash".into(),
                tool_input: json!({"command": "rm -rf /"}),
                tool_use_id: "tu-1".into(),
                suggestions: vec!["allow_once".into()],
            },
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: ControlRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req, decoded);
    }

    #[test]
    fn control_request_wire_format() {
        let req = ControlRequest {
            request_id: "req-1".into(),
            request: ControlRequestData::PermissionRequest {
                tool_name: "bash".into(),
                tool_input: json!({"command": "ls"}),
                tool_use_id: "tu-1".into(),
                suggestions: Vec::new(),
            },
        };
        assert_eq!(
            serde_json::to_value(&req).unwrap(),
            json!({
                "request_id": "req-1",
                "request": {
                    "type": "permission_request",
                    "tool_name": "bash",
                    "tool_input": {"command": "ls"},
                    "tool_use_id": "tu-1",
                    "suggestions": []
                }
            })
        );
    }

    #[test]
    fn control_request_suggestions_default_to_empty() {
        let decoded: ControlRequest = serde_json::from_value(json!({
            "request_id": "req-2",
            "request": {
                "type": "permission_request",
                "tool_name": "bash",
                "tool_input": {},
                "tool_use_id": "tu-2"
            }
        }))
        .unwrap();
        assert_eq!(decoded.request_id, "req-2");
        assert_eq!(
            decoded.request,
            ControlRequestData::PermissionRequest {
                tool_name: "bash".into(),
                tool_input: json!({}),
                tool_use_id: "tu-2".into(),
                suggestions: Vec::new(),
            }
        );
    }

    #[test]
    fn control_response_allow_round_trip() {
        let resp = ControlResponse {
            kind: "permission_response".into(),
            request_id: "req-1".into(),
            result: ControlResponseResult::Allow {
                updated_input: None,
            },
        };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: ControlResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(resp, decoded);
    }

    #[test]
    fn control_response_deny_round_trip() {
        let resp = ControlResponse {
            kind: "permission_response".into(),
            request_id: "req-1".into(),
            result: ControlResponseResult::Deny {
                message: "dangerous".into(),
                interrupt: true,
            },
        };
        let json = serde_json::to_string(&resp).unwrap();
        let decoded: ControlResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(resp, decoded);
    }

    #[test]
    fn control_response_wire_format() {
        let resp = ControlResponse {
            kind: "permission_response".into(),
            request_id: "req-1".into(),
            result: ControlResponseResult::Deny {
                message: "dangerous".into(),
                interrupt: true,
            },
        };
        assert_eq!(
            serde_json::to_value(&resp).unwrap(),
            json!({
                "kind": "permission_response",
                "request_id": "req-1",
                "result": {"type": "deny", "message": "dangerous", "interrupt": true}
            })
        );
    }

    /// Checks that the conversion carries every field across, not just the variant.
    #[test]
    fn permission_decision_to_control_response_result() {
        let result: ControlResponseResult = PermissionDecision::allow().into();
        assert_eq!(
            result,
            ControlResponseResult::Allow {
                updated_input: None
            }
        );

        let input = json!({"modified": true});
        let result: ControlResponseResult =
            PermissionDecision::allow_with_input(input.clone()).into();
        assert_eq!(
            result,
            ControlResponseResult::Allow {
                updated_input: Some(input)
            }
        );

        let result: ControlResponseResult = PermissionDecision::deny("no").into();
        assert_eq!(
            result,
            ControlResponseResult::Deny {
                message: "no".into(),
                interrupt: false,
            }
        );

        let result: ControlResponseResult = PermissionDecision::deny_and_interrupt("stop").into();
        assert_eq!(
            result,
            ControlResponseResult::Deny {
                message: "stop".into(),
                interrupt: true,
            }
        );
    }

    #[test]
    fn legacy_permission_request_round_trip() {
        let req = LegacyPermissionRequest {
            request_id: "lr-1".into(),
            tool_name: "read_file".into(),
            action: "read /etc/passwd".into(),
            details: Some("attempting to read sensitive file".into()),
        };
        let json = serde_json::to_string(&req).unwrap();
        let decoded: LegacyPermissionRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(req, decoded);
    }

    #[test]
    fn legacy_permission_request_details_default_to_none() {
        let decoded: LegacyPermissionRequest = serde_json::from_value(json!({
            "request_id": "lr-2",
            "tool_name": "read_file",
            "action": "read README"
        }))
        .unwrap();
        assert_eq!(decoded.details, None);
    }
}