use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use opi_agent::extension::{Extension, ExtensionHookResult, ExtensionRegistry};
use opi_agent::hooks::AgentHooks;
use opi_agent::loop_types::{AgentError, AgentLoopConfig};
use opi_agent::message::AgentMessage;
use opi_agent::tool::{ExecutionMode, Tool, ToolError, ToolResult};
use opi_ai::message::{OutputContent, ToolDef};
use opi_ai::test_support::{MockProvider, text_response, tool_call_response};
use tokio_util::sync::CancellationToken;
/// Rule set the permission gate applies to each requested tool call.
///
/// Its `Debug` rendering is embedded verbatim in denial reasons.
#[derive(Clone, Debug)]
enum PermissionPolicy {
    /// Every tool is permitted.
    AllowAll,
    /// Every tool is refused.
    DenyAll,
    /// Every tool is permitted except the ones named here (exact, case-sensitive match).
    DenyList(Vec<String>),
    /// Every tool is refused except the ones named here (exact, case-sensitive match).
    AllowList(Vec<String>),
}
/// One recorded permission decision.
#[derive(Clone, Debug)]
struct AuditEntry {
    /// Name of the tool whose call was evaluated.
    tool_name: String,
    /// Outcome as text: `"allowed"` or `"denied"` when produced by the gate.
    decision: String,
    /// Explanation attached to denials; `None` for allowed calls.
    reason: Option<String>,
}
/// Test extension that gates tool calls according to a `PermissionPolicy`.
///
/// Every decision is appended to `audit_log`. Every event seen through `on_event` is
/// recorded as a short label in `events_received`. Both are shared behind
/// `Arc<Mutex<..>>` so tests can keep a handle after the extension is moved into an
/// `ExtensionRegistry`.
struct PermissionGateExtension {
    policy: PermissionPolicy,
    audit_log: Arc<Mutex<Vec<AuditEntry>>>,
    events_received: Arc<Mutex<Vec<String>>>,
}

impl PermissionGateExtension {
    /// Creates a gate enforcing `policy`, with empty audit and event logs.
    fn new(policy: PermissionPolicy) -> Self {
        Self {
            policy,
            audit_log: Arc::new(Mutex::new(Vec::new())),
            events_received: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Returns `true` when the configured policy permits calling `tool_name`.
    ///
    /// List policies compare tool names exactly (case-sensitive).
    fn is_allowed(&self, tool_name: &str) -> bool {
        match &self.policy {
            PermissionPolicy::AllowAll => true,
            PermissionPolicy::DenyAll => false,
            PermissionPolicy::DenyList(denied) => !denied.iter().any(|t| t == tool_name),
            PermissionPolicy::AllowList(allowed) => allowed.iter().any(|t| t == tool_name),
        }
    }

    /// Decides whether `tool_name` may run, records the decision in the audit log, and
    /// returns the matching hook result.
    ///
    /// A denial carries a reason naming the tool and the policy (through its `Debug`
    /// form). The same reason is stored in the audit entry and returned in
    /// `ExtensionHookResult::Block`. Allowed calls are logged with no reason and yield
    /// `ExtensionHookResult::Continue`.
    ///
    /// # Panics
    ///
    /// Panics if the audit-log mutex is poisoned.
    fn evaluate(&self, tool_name: &str) -> ExtensionHookResult {
        let allowed = self.is_allowed(tool_name);
        // The reason exists exactly when the call is denied, so no later unwrap is needed.
        let reason = (!allowed).then(|| {
            format!(
                "permission gate denied '{}' based on {:?} policy",
                tool_name, self.policy
            )
        });
        // The audit log stores the decision as text so it serializes and compares directly.
        let decision = if allowed { "allowed" } else { "denied" };
        self.audit_log.lock().unwrap().push(AuditEntry {
            tool_name: tool_name.to_string(),
            decision: decision.to_string(),
            reason: reason.clone(),
        });
        match reason {
            Some(reason) => ExtensionHookResult::Block { reason },
            None => ExtensionHookResult::Continue,
        }
    }
}
impl Extension for PermissionGateExtension {
    /// Name under which this extension registers.
    fn name(&self) -> &str {
        "permission-gate"
    }

    /// Applies the permission policy to `tool_name`. The tool arguments are not inspected.
    ///
    /// The decision and its audit-log entry are produced eagerly, before the future is
    /// returned. The boxed future therefore owns its result and borrows nothing from `self`.
    fn on_before_tool_call(
        &self,
        tool_name: &str,
        _args: &serde_json::Value,
    ) -> Pin<Box<dyn Future<Output = ExtensionHookResult> + Send>> {
        let verdict = self.evaluate(tool_name);
        Box::pin(async move { verdict })
    }

    /// Appends a short label for `event` to `events_received`.
    ///
    /// Variants without a dedicated label are recorded as `"Other"`.
    fn on_event(&self, event: &opi_agent::event::AgentEvent) {
        use opi_agent::event::AgentEvent as Ev;

        let label = match event {
            Ev::AgentStart => String::from("AgentStart"),
            Ev::AgentEnd { .. } => String::from("AgentEnd"),
            Ev::TurnStart => String::from("TurnStart"),
            Ev::ToolExecutionStart { tool_name, .. } => {
                format!("ToolExecutionStart({tool_name})")
            }
            Ev::ToolExecutionEnd { tool_name, .. } => {
                format!("ToolExecutionEnd({tool_name})")
            }
            _ => String::from("Other"),
        };
        self.events_received.lock().unwrap().push(label);
    }

    /// Snapshots the audit log as `{ "audit_log": [ {tool_name, decision, reason}, ... ] }`.
    fn serialize_state(
        &self,
    ) -> Result<Option<serde_json::Value>, opi_agent::extension::ExtensionError> {
        let guard = self.audit_log.lock().unwrap();
        let mut entries = Vec::with_capacity(guard.len());
        for record in guard.iter() {
            entries.push(serde_json::json!({
                "tool_name": record.tool_name,
                "decision": record.decision,
                "reason": record.reason,
            }));
        }
        Ok(Some(serde_json::json!({ "audit_log": entries })))
    }

    /// Replaces the audit log with the entries found under `state["audit_log"]`.
    ///
    /// If that key is missing or not an array, the current log is left untouched and
    /// `Ok(())` is returned. Within an entry, absent or non-string `tool_name`/`decision`
    /// fields become `""`, and an absent or non-string `reason` becomes `None`.
    fn restore_state(
        &self,
        state: serde_json::Value,
    ) -> Result<(), opi_agent::extension::ExtensionError> {
        let saved = match state["audit_log"].as_array() {
            Some(saved) => saved,
            None => return Ok(()),
        };
        let mut guard = self.audit_log.lock().unwrap();
        guard.clear();
        guard.extend(saved.iter().map(|item| {
            let text = |key: &str| item[key].as_str().unwrap_or("").to_string();
            AuditEntry {
                tool_name: text("tool_name"),
                decision: text("decision"),
                reason: item["reason"].as_str().map(str::to_string),
            }
        }));
        Ok(())
    }
}
/// Minimal tool stub that answers every call with `"ok"`; only its name varies.
struct DummyTool {
    /// Name reported in the tool definition.
    name: String,
}

impl DummyTool {
    /// Creates a stub tool registered under `name`.
    fn new(name: &str) -> Self {
        let name = String::from(name);
        Self { name }
    }
}
impl Tool for DummyTool {
    /// Describes the tool as `"<name> tool"` with an empty object input schema.
    fn definition(&self) -> ToolDef {
        let description = format!("{} tool", self.name);
        let spec = serde_json::json!({
            "name": self.name,
            "description": description,
            "input_schema": { "type": "object", "properties": {} }
        });
        serde_json::from_value(spec).unwrap()
    }

    /// Ignores every input and immediately succeeds with the single text block `"ok"`.
    fn execute(
        &self,
        _call_id: &str,
        _arguments: serde_json::Value,
        _signal: CancellationToken,
        _on_update: Option<opi_agent::tool::UpdateCallback>,
    ) -> Pin<Box<dyn Future<Output = Result<ToolResult, ToolError>> + Send>> {
        let work = async {
            let content = vec![OutputContent::Text { text: "ok".into() }];
            let outcome: Result<ToolResult, ToolError> = Ok(ToolResult {
                content,
                details: None,
                is_error: false,
                terminate: false,
            });
            outcome
        };
        Box::pin(work)
    }

    /// Declares that this tool may run concurrently with other tool calls.
    fn execution_mode(&self) -> ExecutionMode {
        ExecutionMode::Parallel
    }
}
/// Hooks implementation that forwards only the LLM-level messages to the provider.
struct TestHooks;

impl AgentHooks for TestHooks {
    /// Keeps each `AgentMessage::Llm` payload (cloned, in order) and drops every other
    /// variant. This implementation never returns an error.
    fn convert_to_llm(
        &self,
        messages: &[AgentMessage],
    ) -> Result<Vec<opi_ai::message::Message>, AgentError> {
        let mut forwarded = Vec::new();
        for message in messages {
            if let AgentMessage::Llm(inner) = message {
                forwarded.push(inner.clone());
            }
        }
        Ok(forwarded)
    }
}
/// Collects the text of every tool-result message in `messages`, one fragment per line.
///
/// Messages that are not `AgentMessage::Llm(Message::ToolResult(..))` are skipped, as are
/// non-text content parts. Fragments are joined with `'\n'` rather than glued together,
/// so a substring assertion cannot accidentally match across the boundary between two
/// unrelated tool results.
fn extract_tool_result_text(messages: &[AgentMessage]) -> String {
    messages
        .iter()
        .filter_map(|m| match m {
            AgentMessage::Llm(opi_ai::message::Message::ToolResult(trm)) => {
                Some(trm.content.clone())
            }
            _ => None,
        })
        .flat_map(|content| {
            content.into_iter().filter_map(|part| match part {
                OutputContent::Text { text } => Some(text),
                _ => None,
            })
        })
        .collect::<Vec<String>>()
        .join("\n")
}
#[tokio::test]
async fn allow_all_policy_permits_tool_call() {
let ext = PermissionGateExtension::new(PermissionPolicy::AllowAll);
let audit = ext.audit_log.clone();
let mut registry = ExtensionRegistry::new();
registry.register(Box::new(ext)).unwrap();
let provider = MockProvider::new(
"mock",
vec![
tool_call_response("tc_1", "write", r#"{"path":"/tmp/f","content":"x"}"#),
text_response("Done"),
],
);
let hooks = registry.wrap_hooks(Box::new(TestHooks));
let mut agent = opi_agent::Agent::new(
Box::new(provider),
vec![Box::new(DummyTool::new("write"))],
"mock:model".into(),
None,
AgentLoopConfig {
max_turns: 10,
..Default::default()
},
hooks,
);
let result = agent.prompt("test").await.unwrap();
assert!(result.len() >= 3);
let tool_text = extract_tool_result_text(&result);
assert!(
tool_text.contains("ok"),
"tool should have executed, got: {tool_text}"
);
let log = audit.lock().unwrap();
assert_eq!(log.len(), 1);
assert_eq!(log[0].tool_name, "write");
assert_eq!(log[0].decision, "allowed");
assert!(log[0].reason.is_none());
}
#[tokio::test]
async fn allow_list_policy_permits_listed_tool() {
let ext = PermissionGateExtension::new(PermissionPolicy::AllowList(vec![
"read".to_string(),
"glob".to_string(),
]));
let audit = ext.audit_log.clone();
let mut registry = ExtensionRegistry::new();
registry.register(Box::new(ext)).unwrap();
let provider = MockProvider::new(
"mock",
vec![
tool_call_response("tc_1", "read", r#"{"path":"/tmp/f"}"#),
text_response("Done"),
],
);
let hooks = registry.wrap_hooks(Box::new(TestHooks));
let mut agent = opi_agent::Agent::new(
Box::new(provider),
vec![Box::new(DummyTool::new("read"))],
"mock:model".into(),
None,
AgentLoopConfig {
max_turns: 10,
..Default::default()
},
hooks,
);
let result = agent.prompt("test").await.unwrap();
let tool_text = extract_tool_result_text(&result);
assert!(tool_text.contains("ok"), "listed tool should execute");
let log = audit.lock().unwrap();
assert_eq!(log.len(), 1);
assert_eq!(log[0].decision, "allowed");
}
#[tokio::test]
async fn deny_all_policy_blocks_tool_call() {
let ext = PermissionGateExtension::new(PermissionPolicy::DenyAll);
let audit = ext.audit_log.clone();
let mut registry = ExtensionRegistry::new();
registry.register(Box::new(ext)).unwrap();
let provider = MockProvider::new(
"mock",
vec![
tool_call_response("tc_1", "bash", r#"{"command":"rm -rf /"}"#),
text_response("Done"),
],
);
let hooks = registry.wrap_hooks(Box::new(TestHooks));
let mut agent = opi_agent::Agent::new(
Box::new(provider),
vec![Box::new(DummyTool::new("bash"))],
"mock:model".into(),
None,
AgentLoopConfig {
max_turns: 10,
..Default::default()
},
hooks,
);
let result = agent.prompt("test").await.unwrap();
let tool_text = extract_tool_result_text(&result);
assert!(
tool_text.contains("permission gate denied"),
"tool should be blocked, got: {tool_text}"
);
assert!(!tool_text.contains("ok"), "tool should NOT have executed");
let log = audit.lock().unwrap();
assert_eq!(log.len(), 1);
assert_eq!(log[0].tool_name, "bash");
assert_eq!(log[0].decision, "denied");
assert!(log[0].reason.is_some());
}
#[tokio::test]
async fn deny_list_policy_blocks_specific_tools() {
let ext = PermissionGateExtension::new(PermissionPolicy::DenyList(vec![
"write".to_string(),
"edit".to_string(),
"bash".to_string(),
]));
let audit = ext.audit_log.clone();
let mut registry = ExtensionRegistry::new();
registry.register(Box::new(ext)).unwrap();
let provider = MockProvider::new(
"mock",
vec![
tool_call_response("tc_1", "write", r#"{"path":"/tmp/f","content":"x"}"#),
text_response("Done"),
],
);
let hooks = registry.wrap_hooks(Box::new(TestHooks));
let mut agent = opi_agent::Agent::new(
Box::new(provider),
vec![Box::new(DummyTool::new("write"))],
"mock:model".into(),
None,
AgentLoopConfig {
max_turns: 10,
..Default::default()
},
hooks,
);
let result = agent.prompt("test").await.unwrap();
let tool_text = extract_tool_result_text(&result);
assert!(
tool_text.contains("permission gate denied"),
"denied tool should be blocked, got: {tool_text}"
);
let log = audit.lock().unwrap();
assert_eq!(log.len(), 1);
assert_eq!(log[0].tool_name, "write");
assert_eq!(log[0].decision, "denied");
}
#[tokio::test]
async fn deny_list_policy_allows_non_listed_tools() {
let ext = PermissionGateExtension::new(PermissionPolicy::DenyList(vec![
"write".to_string(),
"edit".to_string(),
"bash".to_string(),
]));
let audit = ext.audit_log.clone();
let mut registry = ExtensionRegistry::new();
registry.register(Box::new(ext)).unwrap();
let provider = MockProvider::new(
"mock",
vec![
tool_call_response("tc_1", "read", r#"{"path":"/tmp/f"}"#),
text_response("Done"),
],
);
let hooks = registry.wrap_hooks(Box::new(TestHooks));
let mut agent = opi_agent::Agent::new(
Box::new(provider),
vec![Box::new(DummyTool::new("read"))],
"mock:model".into(),
None,
AgentLoopConfig {
max_turns: 10,
..Default::default()
},
hooks,
);
let result = agent.prompt("test").await.unwrap();
let tool_text = extract_tool_result_text(&result);
assert!(
tool_text.contains("ok"),
"non-listed tool should execute, got: {tool_text}"
);
let log = audit.lock().unwrap();
assert_eq!(log.len(), 1);
assert_eq!(log[0].decision, "allowed");
}
#[tokio::test]
async fn audit_log_records_allow_and_deny_across_turns() {
let ext = PermissionGateExtension::new(PermissionPolicy::AllowList(vec!["read".to_string()]));
let audit = ext.audit_log.clone();
let mut registry = ExtensionRegistry::new();
registry.register(Box::new(ext)).unwrap();
let provider = MockProvider::new(
"mock",
vec![
tool_call_response("tc_1", "read", r#"{"path":"/tmp/f"}"#),
tool_call_response("tc_2", "write", r#"{"path":"/tmp/f","content":"x"}"#),
text_response("Done"),
],
);
let hooks = registry.wrap_hooks(Box::new(TestHooks));
let mut agent = opi_agent::Agent::new(
Box::new(provider),
vec![
Box::new(DummyTool::new("read")),
Box::new(DummyTool::new("write")),
],
"mock:model".into(),
None,
AgentLoopConfig {
max_turns: 10,
..Default::default()
},
hooks,
);
let _ = agent.prompt("test").await.unwrap();
let log = audit.lock().unwrap();
assert!(log.len() >= 2, "should have at least 2 audit entries");
assert_eq!(log[0].tool_name, "read");
assert_eq!(log[0].decision, "allowed");
assert_eq!(log[1].tool_name, "write");
assert_eq!(log[1].decision, "denied");
assert!(log[1].reason.is_some());
}
#[tokio::test]
async fn extension_receives_agent_events() {
let ext = PermissionGateExtension::new(PermissionPolicy::AllowAll);
let events = ext.events_received.clone();
let mut registry = ExtensionRegistry::new();
registry.register(Box::new(ext)).unwrap();
let base_sink =
Box::new(|_event: opi_agent::event::AgentEvent| {}) as opi_agent::event::AgentEventSink;
let wrapped_sink = registry.wrap_event_sink(base_sink);
wrapped_sink(opi_agent::event::AgentEvent::AgentStart);
wrapped_sink(opi_agent::event::AgentEvent::TurnStart);
wrapped_sink(opi_agent::event::AgentEvent::ToolExecutionStart {
tool_call_id: "tc_1".into(),
tool_name: "read".into(),
args: serde_json::json!({}),
});
let received = events.lock().unwrap();
assert!(
received.contains(&"AgentStart".to_string()),
"should have received AgentStart"
);
assert!(
received.contains(&"TurnStart".to_string()),
"should have received TurnStart"
);
assert!(
received.contains(&"ToolExecutionStart(read)".to_string()),
"should have received ToolExecutionStart(read)"
);
}
/// Audit state written by `serialize_state` must be fully rebuilt by `restore_state`,
/// including denial reasons. Restoring must replace any existing log, not append to it.
///
/// Nothing here awaits, so a plain synchronous test is sufficient.
#[test]
fn audit_state_round_trips_through_serialization() {
    let ext = PermissionGateExtension::new(PermissionPolicy::DenyList(vec!["bash".to_string()]));
    // Only the audit side effect matters here; the hook results are discarded.
    let _ = ext.evaluate("read");
    let _ = ext.evaluate("bash");
    let expected_reason = {
        // Scoped so the guard is released before serialize_state locks the log again.
        let log = ext.audit_log.lock().unwrap();
        log[1].reason.clone()
    };
    assert!(expected_reason.is_some());
    let state = ext.serialize_state().unwrap().unwrap();
    assert_eq!(state["audit_log"].as_array().unwrap().len(), 2);

    let ext2 = PermissionGateExtension::new(PermissionPolicy::AllowAll);
    // Pre-existing entry that restore_state is expected to discard.
    let _ = ext2.evaluate("stale");
    ext2.restore_state(state).unwrap();
    let log = ext2.audit_log.lock().unwrap();
    assert_eq!(log.len(), 2);
    assert_eq!(log[0].tool_name, "read");
    assert_eq!(log[0].decision, "allowed");
    assert!(log[0].reason.is_none());
    assert_eq!(log[1].tool_name, "bash");
    assert_eq!(log[1].decision, "denied");
    assert_eq!(log[1].reason, expected_reason);
}
#[tokio::test]
async fn non_interactive_auto_approves_with_allow_all() {
let ext = PermissionGateExtension::new(PermissionPolicy::AllowAll);
let audit = ext.audit_log.clone();
let mut registry = ExtensionRegistry::new();
registry.register(Box::new(ext)).unwrap();
let provider = MockProvider::new(
"mock",
vec![
tool_call_response("tc_1", "write", r#"{"path":"/tmp/a","content":"data"}"#),
tool_call_response("tc_2", "bash", r#"{"command":"ls"}"#),
text_response("Done"),
],
);
let hooks = registry.wrap_hooks(Box::new(TestHooks));
let mut agent = opi_agent::Agent::new(
Box::new(provider),
vec![
Box::new(DummyTool::new("write")),
Box::new(DummyTool::new("bash")),
],
"mock:model".into(),
None,
AgentLoopConfig {
max_turns: 10,
..Default::default()
},
hooks,
);
let result = agent.prompt("test").await.unwrap();
assert!(result.len() >= 3);
let log = audit.lock().unwrap();
assert!(log.len() >= 2, "should have audited both tool calls");
assert!(log.iter().all(|e| e.decision == "allowed"));
}
#[tokio::test]
async fn non_interactive_auto_denies_with_deny_all() {
let ext = PermissionGateExtension::new(PermissionPolicy::DenyAll);
let audit = ext.audit_log.clone();
let mut registry = ExtensionRegistry::new();
registry.register(Box::new(ext)).unwrap();
let provider = MockProvider::new(
"mock",
vec![
tool_call_response("tc_1", "write", r#"{"path":"/tmp/a","content":"data"}"#),
text_response("Done"),
],
);
let hooks = registry.wrap_hooks(Box::new(TestHooks));
let mut agent = opi_agent::Agent::new(
Box::new(provider),
vec![Box::new(DummyTool::new("write"))],
"mock:model".into(),
None,
AgentLoopConfig {
max_turns: 10,
..Default::default()
},
hooks,
);
let result = agent.prompt("test").await.unwrap();
let tool_text = extract_tool_result_text(&result);
assert!(
tool_text.contains("permission gate denied"),
"should be auto-denied, got: {tool_text}"
);
let log = audit.lock().unwrap();
assert_eq!(log.len(), 1);
assert_eq!(log[0].decision, "denied");
}
#[tokio::test]
async fn non_interactive_auto_denies_with_allow_list_for_unlisted() {
let ext = PermissionGateExtension::new(PermissionPolicy::AllowList(vec!["read".to_string()]));
let audit = ext.audit_log.clone();
let mut registry = ExtensionRegistry::new();
registry.register(Box::new(ext)).unwrap();
let provider = MockProvider::new(
"mock",
vec![
tool_call_response("tc_1", "bash", r#"{"command":"ls"}"#),
text_response("Done"),
],
);
let hooks = registry.wrap_hooks(Box::new(TestHooks));
let mut agent = opi_agent::Agent::new(
Box::new(provider),
vec![Box::new(DummyTool::new("bash"))],
"mock:model".into(),
None,
AgentLoopConfig {
max_turns: 10,
..Default::default()
},
hooks,
);
let result = agent.prompt("test").await.unwrap();
let tool_text = extract_tool_result_text(&result);
assert!(
tool_text.contains("permission gate denied"),
"unlisted tool should be auto-denied, got: {tool_text}"
);
let log = audit.lock().unwrap();
assert_eq!(log.len(), 1);
assert_eq!(log[0].tool_name, "bash");
assert_eq!(log[0].decision, "denied");
}