use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::jsonrpc::JsonRpcId;
use crate::transport::ReverseRequestHandler;
use crate::wire;
/// Shared, thread-safe map from a tool call id to the raw input that accompanied its
/// permission request.
///
/// `PermissionHandler` inserts into this map (keyed by the wire `tool_call_id`) before it
/// invokes the user callback. NOTE(review): nothing in this file removes entries — confirm
/// that a consumer elsewhere prunes it, otherwise it grows for as long as it is shared.
pub(crate) type ToolInputCache = Arc<Mutex<HashMap<String, Value>>>;
/// Decision returned by a [`CanUseToolCallback`] for a single tool permission request.
///
/// Serialized as an internally tagged object whose `"decision"` field holds the snake_case
/// variant name, e.g. `{"decision": "allow"}` or
/// `{"decision": "deny", "message": "...", "interrupt": false}`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "decision", rename_all = "snake_case")]
pub enum PermissionDecision {
    /// Let the tool call proceed.
    Allow {
        /// Replacement input to forward with the approval; `None` forwards nothing.
        /// Omitted when serializing if `None`, and defaults to `None` when absent.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        updated_input: Option<Value>,
    },
    /// Refuse the tool call.
    Deny {
        /// Explanation for the refusal. NOTE: `PermissionHandler` does not place this in
        /// the wire response (the outcome type has no message field).
        message: String,
        /// When `true`, `PermissionHandler` sends `interrupt: true` alongside the cancelled
        /// outcome. Defaults to `false` when absent during deserialization.
        #[serde(default)]
        interrupt: bool,
    },
}
impl PermissionDecision {
pub fn allow() -> Self {
PermissionDecision::Allow { updated_input: None }
}
pub fn allow_with_input(input: Value) -> Self {
PermissionDecision::Allow { updated_input: Some(input) }
}
pub fn deny(message: impl Into<String>) -> Self {
PermissionDecision::Deny { message: message.into(), interrupt: false }
}
pub fn deny_and_interrupt(message: impl Into<String>) -> Self {
PermissionDecision::Deny { message: message.into(), interrupt: true }
}
}
/// Metadata handed to a [`CanUseToolCallback`] alongside the tool name and input.
#[derive(Debug, Clone)]
pub struct PermissionContext {
    /// Id of the tool call being authorized (the wire `toolCallId`).
    pub tool_use_id: String,
    /// Id of the session the request belongs to.
    pub session_id: String,
    /// JSON-RPC id of the permission request, rendered with `to_string()`.
    pub request_id: String,
    /// Permission suggestions; `PermissionHandler` currently always leaves this empty.
    pub suggestions: Vec<String>,
    /// The options offered in the request, in the order they were received.
    pub permission_options: Vec<PermissionOptionInfo>,
    /// The tool call's `kind`; `PermissionHandler` always sets this to `Some`.
    pub tool_kind: Option<String>,
    /// Locations reported by the tool call.
    pub tool_locations: Vec<ToolLocationInfo>,
}
/// One choice offered in a permission request, copied from the wire option.
#[derive(Debug, Clone)]
pub struct PermissionOptionInfo {
    /// Identifier echoed back as the response's `optionId` when this option is selected
    /// (e.g. `"ProceedOnce"`).
    pub id: String,
    /// Short display label (e.g. `"Allow"`).
    pub label: String,
    /// Longer description of the option.
    pub description: String,
}
/// A location attached to a tool call, copied from the wire request.
#[derive(Debug, Clone)]
pub struct ToolLocationInfo {
    /// URI of the location.
    pub uri: String,
    /// Optional range, passed through as raw JSON; its structure is not interpreted here.
    pub range: Option<Value>,
}
/// User-supplied hook that decides whether a tool call may proceed.
///
/// Arguments are the tool name (when invoked by `PermissionHandler`, this is the tool call's
/// `kind`), the tool's raw input, and a [`PermissionContext`]. It returns a boxed `Send`
/// future resolving to a [`PermissionDecision`]. The boxed future carries no lifetime, so it
/// cannot borrow the `&str`/`&Value` arguments — clone what it needs before `async move`.
pub type CanUseToolCallback = Arc<
    dyn Fn(&str, &Value, PermissionContext) -> Pin<Box<dyn Future<Output = PermissionDecision> + Send>>
        + Send
        + Sync,
>;
/// Reverse-request handler that answers permission requests by consulting a
/// [`CanUseToolCallback`].
pub(crate) struct PermissionHandler {
    /// Produces the decision for each request.
    callback: CanUseToolCallback,
    /// When present, receives each tool call's raw input keyed by tool call id.
    input_cache: Option<ToolInputCache>,
}
impl PermissionHandler {
pub fn new(callback: CanUseToolCallback, input_cache: Option<ToolInputCache>) -> Self {
Self { callback, input_cache }
}
}
#[async_trait]
impl ReverseRequestHandler for PermissionHandler {
    /// Answers a permission reverse-request.
    ///
    /// Parses `params` as `wire::RequestPermissionParams` and records the tool call's raw
    /// input in the input cache (if one is configured). It then asks the user callback for a
    /// decision and maps that decision to a serialized `wire::RequestPermissionResponse`:
    ///
    /// - unparseable params → `"cancelled"` (the raw params are logged at `warn`);
    /// - `Allow` → `"selected"` with the first allow/proceed-looking option, or else the
    ///   first option that is not an explicit refusal; `updatedInput` is forwarded. If no
    ///   suitable option exists, the response is `"cancelled"`;
    /// - `Deny` → `"cancelled"`, with `interrupt: true` when requested. The deny message has
    ///   no wire field, so it is only logged.
    ///
    /// Never fails: a serialization error degrades to `Value::Null`.
    async fn handle_permission_request(&self, id: JsonRpcId, params: Value) -> Value {
        // Deserialize from a borrow: `from_value` would require an owned copy of the whole
        // request, and the original is only needed again for the diagnostic log.
        let wire_params = match wire::RequestPermissionParams::deserialize(&params) {
            Ok(p) => p,
            Err(e) => {
                tracing::warn!(
                    error = %e,
                    raw_params = %serde_json::to_string_pretty(&params).unwrap_or_default(),
                    "Failed to parse permission request params"
                );
                return cancelled_response(None);
            }
        };
        let tool_call = &wire_params.tool_call;
        tracing::debug!(
            tool = %tool_call.kind,
            options = wire_params.options.len(),
            "Permission request parsed successfully"
        );

        if let Some(cache) = &self.input_cache {
            // A poisoned lock only means another holder panicked; a plain insert into the
            // map is still sound, so recover the guard instead of silently dropping the entry.
            // The guard is dropped at the end of this block, before any `.await`.
            let mut map = cache
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            map.insert(tool_call.tool_call_id.clone(), tool_call.raw_input.clone());
        }

        let context = PermissionContext {
            tool_use_id: tool_call.tool_call_id.clone(),
            session_id: wire_params.session_id.clone(),
            request_id: id.to_string(),
            // This handler has no suggestions to offer.
            suggestions: Vec::new(),
            permission_options: wire_params
                .options
                .iter()
                .map(|opt| PermissionOptionInfo {
                    id: opt.id.clone(),
                    label: opt.label.clone(),
                    description: opt.description.clone(),
                })
                .collect(),
            tool_kind: Some(tool_call.kind.clone()),
            tool_locations: tool_call
                .locations
                .iter()
                .map(|loc| ToolLocationInfo {
                    uri: loc.uri.clone(),
                    range: loc.range.clone(),
                })
                .collect(),
        };

        let decision = (self.callback)(&tool_call.kind, &tool_call.raw_input, context).await;

        match decision {
            PermissionDecision::Allow { updated_input } => {
                let option_id = wire_params
                    .options
                    .iter()
                    .find(|opt| is_allow_option(&opt.id, &opt.label))
                    // No explicit allow/proceed option: take the first option that is not an
                    // explicit refusal. Blindly taking `options[0]` could answer an Allow by
                    // selecting e.g. a "reject always" option.
                    .or_else(|| {
                        wire_params
                            .options
                            .iter()
                            .find(|opt| !is_refusal_option(&opt.id, &opt.label))
                    })
                    .map(|opt| opt.id.clone());
                match option_id {
                    Some(selected) => outcome_response(wire::RequestPermissionOutcome {
                        outcome: "selected".to_string(),
                        option_id: Some(selected),
                        updated_input,
                        interrupt: None,
                    }),
                    None => cancelled_response(None),
                }
            }
            PermissionDecision::Deny { message, interrupt } => {
                tracing::debug!(
                    reason = %message,
                    interrupt,
                    "Permission request denied by callback"
                );
                cancelled_response(interrupt.then_some(true))
            }
        }
    }
}

/// Lower-cased phrases that mark a permission option as a refusal.
///
/// They are checked before the positive keywords because a plain substring test for
/// `"allow"` also matches ids/labels such as `"DisallowOnce"` or `"Don't allow"`.
const REFUSAL_OPTION_MARKERS: [&str; 8] = [
    "disallow",
    "deny",
    "reject",
    "cancel",
    "not allow",
    "n't allow",
    "n\u{2019}t allow",
    "dont allow",
];

/// Returns `true` when an option's id or label contains a refusal marker (case-insensitive).
fn is_refusal_option(id: &str, label: &str) -> bool {
    let (id, label) = (id.to_lowercase(), label.to_lowercase());
    REFUSAL_OPTION_MARKERS
        .iter()
        .any(|&marker| id.contains(marker) || label.contains(marker))
}

/// Returns `true` when an option looks like an allow/proceed choice.
///
/// An option qualifies when its id contains `"proceed"` or `"allow"`, or its label contains
/// `"allow"` (case-insensitive), and it is not a refusal per [`is_refusal_option`].
fn is_allow_option(id: &str, label: &str) -> bool {
    if is_refusal_option(id, label) {
        return false;
    }
    let id = id.to_lowercase();
    id.contains("proceed") || id.contains("allow") || label.to_lowercase().contains("allow")
}

/// Serializes `outcome` as the JSON result of a permission request.
///
/// A serialization failure degrades to `Value::Null` instead of panicking.
fn outcome_response(outcome: wire::RequestPermissionOutcome) -> Value {
    serde_json::to_value(wire::RequestPermissionResponse { outcome }).unwrap_or_default()
}

/// Builds a `"cancelled"` response with no option id and no updated input.
///
/// `interrupt` is forwarded unchanged.
fn cancelled_response(interrupt: Option<bool>) -> Value {
    outcome_response(wire::RequestPermissionOutcome {
        outcome: "cancelled".to_string(),
        option_id: None,
        updated_input: None,
        interrupt,
    })
}
#[cfg(test)]
mod tests {
    // Unit tests for the decision constructors, their serde form, and the end-to-end
    // request → callback → wire-response mapping performed by `PermissionHandler`.
    use super::*;
    use serde_json::json;

    #[test]
    fn test_permission_decision_allow() {
        let d = PermissionDecision::allow();
        assert!(matches!(d, PermissionDecision::Allow { updated_input: None }));
    }

    #[test]
    fn test_permission_decision_allow_with_input() {
        let d = PermissionDecision::allow_with_input(json!({"key": "val"}));
        assert!(matches!(d, PermissionDecision::Allow { updated_input: Some(_) }));
    }

    #[test]
    fn test_permission_decision_deny() {
        let d = PermissionDecision::deny("no");
        assert!(matches!(d, PermissionDecision::Deny { interrupt: false, .. }));
    }

    #[test]
    fn test_permission_decision_deny_interrupt() {
        let d = PermissionDecision::deny_and_interrupt("stop");
        assert!(matches!(d, PermissionDecision::Deny { interrupt: true, .. }));
    }

    // Pins the internally tagged serde representation and its defaults.
    #[test]
    fn test_permission_decision_serde_tagging() {
        assert_eq!(
            serde_json::to_value(PermissionDecision::allow()).expect("serializable"),
            json!({"decision": "allow"})
        );
        assert_eq!(
            serde_json::to_value(PermissionDecision::deny_and_interrupt("stop"))
                .expect("serializable"),
            json!({"decision": "deny", "message": "stop", "interrupt": true})
        );
        let parsed: PermissionDecision =
            serde_json::from_value(json!({"decision": "deny", "message": "no"}))
                .expect("interrupt must default to false");
        assert_eq!(parsed, PermissionDecision::deny("no"));
    }

    #[tokio::test]
    async fn test_permission_handler_allow() {
        let callback: CanUseToolCallback = Arc::new(|_name, _input, _ctx| {
            Box::pin(async { PermissionDecision::allow() })
        });
        let handler = PermissionHandler::new(callback, None);
        let params = json!({
            "sessionId": "sess-1",
            "toolCall": {
                "toolCallId": "tc-1",
                "title": "Edit",
                "kind": "edit",
                "rawInput": {},
                "locations": []
            },
            "options": [
                { "id": "ProceedOnce", "label": "Allow", "description": "Allow once" },
                { "id": "Cancel", "label": "Cancel", "description": "Cancel" }
            ]
        });
        let result = handler
            .handle_permission_request(crate::jsonrpc::JsonRpcId::Number(1), params)
            .await;
        let outcome = result.get("outcome").expect("must have outcome");
        assert_eq!(outcome.get("outcome").and_then(|v| v.as_str()), Some("selected"));
        assert_eq!(outcome.get("optionId").and_then(|v| v.as_str()), Some("ProceedOnce"));
    }

    #[tokio::test]
    async fn test_permission_handler_deny() {
        let callback: CanUseToolCallback = Arc::new(|_name, _input, _ctx| {
            Box::pin(async { PermissionDecision::deny("not allowed") })
        });
        let handler = PermissionHandler::new(callback, None);
        let params = json!({
            "sessionId": "sess-1",
            "toolCall": {
                "toolCallId": "tc-1",
                "title": "Execute",
                "kind": "exec",
                "rawInput": {},
                "locations": []
            },
            "options": []
        });
        let result = handler
            .handle_permission_request(crate::jsonrpc::JsonRpcId::Number(2), params)
            .await;
        let outcome = result.get("outcome").expect("must have outcome");
        assert_eq!(outcome.get("outcome").and_then(|v| v.as_str()), Some("cancelled"));
    }

    #[tokio::test]
    async fn test_permission_handler_deny_with_interrupt_sets_flag() {
        let callback: CanUseToolCallback = Arc::new(|_name, _input, _ctx| {
            Box::pin(async { PermissionDecision::deny_and_interrupt("stop") })
        });
        let handler = PermissionHandler::new(callback, None);
        let params = json!({
            "sessionId": "sess-1",
            "toolCall": {
                "toolCallId": "tc-int",
                "title": "Execute",
                "kind": "exec",
                "rawInput": {},
                "locations": []
            },
            "options": [
                { "id": "ProceedOnce", "label": "Allow", "description": "Allow once" }
            ]
        });
        let result = handler
            .handle_permission_request(crate::jsonrpc::JsonRpcId::Number(5), params)
            .await;
        let outcome = result.get("outcome").expect("must have outcome");
        assert_eq!(outcome.get("outcome").and_then(|v| v.as_str()), Some("cancelled"));
        assert_eq!(outcome.get("interrupt").and_then(|v| v.as_bool()), Some(true));
        assert!(
            outcome.get("optionId").is_none(),
            "a denied request must not select any option"
        );
    }

    #[tokio::test]
    async fn test_permission_handler_populates_input_cache() {
        let callback: CanUseToolCallback = Arc::new(|_name, _input, _ctx| {
            Box::pin(async { PermissionDecision::allow() })
        });
        let cache: ToolInputCache = Arc::new(Mutex::new(HashMap::new()));
        let handler = PermissionHandler::new(callback, Some(Arc::clone(&cache)));
        let params = json!({
            "sessionId": "sess-1",
            "toolCall": {
                "toolCallId": "tc-cache",
                "title": "Edit",
                "kind": "edit",
                "rawInput": {"path": "/tmp/a.txt"},
                "locations": []
            },
            "options": [
                { "id": "ProceedOnce", "label": "Allow", "description": "Allow once" }
            ]
        });
        let result = handler
            .handle_permission_request(crate::jsonrpc::JsonRpcId::Number(7), params)
            .await;
        assert!(result.get("outcome").is_some());
        let cached = cache.lock().expect("cache lock must not be poisoned");
        assert_eq!(cached.get("tc-cache"), Some(&json!({"path": "/tmp/a.txt"})));
    }

    #[tokio::test]
    async fn test_permission_handler_parse_error() {
        let callback: CanUseToolCallback = Arc::new(|_name, _input, _ctx| {
            Box::pin(async { PermissionDecision::allow() })
        });
        let handler = PermissionHandler::new(callback, None);
        let result = handler
            .handle_permission_request(
                crate::jsonrpc::JsonRpcId::Number(3),
                json!({"invalid": true}),
            )
            .await;
        let outcome = result.get("outcome").expect("must have outcome");
        assert_eq!(outcome.get("outcome").and_then(|v| v.as_str()), Some("cancelled"));
    }

    #[tokio::test]
    async fn test_permission_handler_allow_with_updated_input() {
        let modified_input = json!({"command": "ls -la /safe/path"});
        let modified_clone = modified_input.clone();
        let callback: CanUseToolCallback = Arc::new(move |_name, _input, _ctx| {
            let input = modified_clone.clone();
            Box::pin(async move { PermissionDecision::allow_with_input(input) })
        });
        let handler = PermissionHandler::new(callback, None);
        let params = json!({
            "sessionId": "sess-1",
            "toolCall": {
                "toolCallId": "tc-upd",
                "title": "Shell",
                "kind": "shell",
                "rawInput": {"command": "rm -rf /"},
                "locations": []
            },
            "options": [
                { "id": "ProceedOnce", "label": "Allow", "description": "Allow once" }
            ]
        });
        let result = handler
            .handle_permission_request(crate::jsonrpc::JsonRpcId::Number(42), params)
            .await;
        let outcome = result.get("outcome").expect("must have outcome");
        assert_eq!(outcome.get("outcome").and_then(|v| v.as_str()), Some("selected"));
        assert_eq!(
            outcome.get("updatedInput"),
            Some(&modified_input),
            "updated_input must be forwarded in the response"
        );
    }

    #[tokio::test]
    async fn test_permission_handler_allow_with_empty_options_returns_cancelled() {
        let callback: CanUseToolCallback = Arc::new(|_name, _input, _ctx| {
            Box::pin(async { PermissionDecision::allow() })
        });
        let handler = PermissionHandler::new(callback, None);
        let params = json!({
            "sessionId": "sess-1",
            "toolCall": {
                "toolCallId": "tc-empty",
                "title": "Edit",
                "kind": "edit",
                "rawInput": {},
                "locations": []
            },
            "options": []
        });
        let result = handler
            .handle_permission_request(crate::jsonrpc::JsonRpcId::Number(99), params)
            .await;
        let outcome = result.get("outcome").expect("must have outcome");
        assert_eq!(
            outcome.get("outcome").and_then(|v| v.as_str()),
            Some("cancelled"),
            "empty options with Allow decision must respond with 'cancelled'"
        );
        assert!(
            outcome.get("optionId").is_none(),
            "optionId must be absent in the cancelled response"
        );
    }
}