use async_trait::async_trait;
use crate::reasoning::loop_types::{LoopDecision, LoopState, ProposedAction};
use crate::types::AgentId;
/// Gate consulted to decide whether an action proposed inside the reasoning loop may proceed.
///
/// Implementors must be `Send + Sync`; `#[async_trait]` is what lets the method be declared
/// `async` on a trait.
#[async_trait]
pub trait ReasoningPolicyGate: Send + Sync {
    /// Evaluates `action`, proposed by the agent identified by `agent_id`, given the current
    /// loop `state`.
    ///
    /// Returns the [`LoopDecision`] for the action. The implementations in this file return
    /// either `LoopDecision::Allow` or `LoopDecision::Deny { reason }`.
    async fn evaluate_action(
        &self,
        agent_id: &AgentId,
        action: &ProposedAction,
        state: &LoopState,
    ) -> LoopDecision;
}
/// Built-in policy gate that performs no external policy lookup.
///
/// The only configuration is whether the gate runs in permissive mode.
#[derive(Debug, Clone)]
pub struct DefaultPolicyGate {
    /// When `true`, `evaluate_action` returns `LoopDecision::Allow` without inspecting the action.
    allow_all: bool,
}

impl DefaultPolicyGate {
    /// Creates a gate in standard (non-permissive) mode.
    pub fn new() -> Self {
        Self { allow_all: false }
    }

    /// Creates a gate in permissive mode, i.e. with `allow_all` set.
    pub fn permissive() -> Self {
        Self { allow_all: true }
    }
}

impl Default for DefaultPolicyGate {
    /// Equivalent to [`DefaultPolicyGate::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[async_trait]
impl ReasoningPolicyGate for DefaultPolicyGate {
    /// Allows every proposed action; this gate never returns a denial.
    ///
    /// In permissive mode the action is allowed immediately, with no logging. In standard mode
    /// tool calls and delegations are traced at debug level before being allowed, while
    /// responses and terminations are allowed silently. `_state` is not consulted.
    async fn evaluate_action(
        &self,
        agent_id: &AgentId,
        action: &ProposedAction,
        _state: &LoopState,
    ) -> LoopDecision {
        if self.allow_all {
            return LoopDecision::Allow;
        }

        // No policy input is built here: nothing in this gate evaluates one, so constructing a
        // JSON document per call would only be a discarded allocation.
        // The match has no `_` arm on purpose, so adding a `ProposedAction` variant forces an
        // explicit choice at this site.
        match action {
            ProposedAction::ToolCall { name, .. } => {
                tracing::debug!(
                    "Policy gate evaluating tool call: agent={} tool={}",
                    agent_id,
                    name
                );
            }
            ProposedAction::Delegate { target, .. } => {
                tracing::debug!(
                    "Policy gate evaluating delegation: agent={} target={}",
                    agent_id,
                    target
                );
            }
            ProposedAction::Respond { .. } | ProposedAction::Terminate { .. } => {}
        }

        LoopDecision::Allow
    }
}
/// Adapter that exposes a `PolicyEngine` through the [`ReasoningPolicyGate`] interface.
pub struct OpaPolicyGateBridge {
    /// Shared handle to the engine queried by `evaluate_action`.
    policy_engine: std::sync::Arc<dyn crate::integrations::policy_engine::engine::PolicyEngine>,
}

impl OpaPolicyGateBridge {
    /// Wraps `engine` in a bridge; the `Arc` is stored as given.
    pub fn new(
        engine: std::sync::Arc<dyn crate::integrations::policy_engine::engine::PolicyEngine>,
    ) -> Self {
        let policy_engine = engine;
        OpaPolicyGateBridge { policy_engine }
    }
}
#[async_trait]
impl ReasoningPolicyGate for OpaPolicyGateBridge {
    /// Serializes `action` into a JSON policy input and asks the policy engine to decide.
    ///
    /// The input always carries a `"type"` field (`tool_call`, `delegate`, `respond` or
    /// `terminate`). Tool calls include the call id, tool name and arguments; delegations
    /// include the target and only the *length* of the message; responses include only the
    /// content length; terminations include the reason. The agent id is passed to the engine
    /// as a separate argument rather than inside the input. `_state` is not consulted.
    ///
    /// Returns `LoopDecision::Allow` when the engine allows the action and
    /// `LoopDecision::Deny` when it denies it. The gate fails closed: an engine error is also
    /// turned into a `Deny` whose reason contains the error.
    async fn evaluate_action(
        &self,
        agent_id: &AgentId,
        action: &ProposedAction,
        _state: &LoopState,
    ) -> LoopDecision {
        let input = match action {
            ProposedAction::ToolCall {
                name,
                arguments,
                call_id,
            } => serde_json::json!({
                "type": "tool_call",
                "call_id": call_id,
                "tool_name": name,
                "arguments": arguments,
            }),
            ProposedAction::Delegate { target, message } => serde_json::json!({
                "type": "delegate",
                "target": target,
                "message_length": message.len(),
            }),
            ProposedAction::Respond { content } => serde_json::json!({
                "type": "respond",
                "content_length": content.len(),
            }),
            ProposedAction::Terminate { reason, .. } => serde_json::json!({
                "type": "terminate",
                "reason": reason,
            }),
        };

        match self
            .policy_engine
            .evaluate_policy(&agent_id.to_string(), &input)
            .await
        {
            Ok(crate::integrations::policy_engine::engine::PolicyDecision::Allow) => {
                LoopDecision::Allow
            }
            Ok(crate::integrations::policy_engine::engine::PolicyDecision::Deny) => {
                // Report the human-readable action type from the policy input. Formatting
                // `std::mem::discriminant(action)` with `{:?}` only yields an opaque
                // `Discriminant(..)` that is useless in logs and deny reasons.
                let action_type = input
                    .get("type")
                    .and_then(|kind| kind.as_str())
                    .unwrap_or("unknown");
                let reason = format!(
                    "Policy denied action {} for agent {}",
                    action_type, agent_id
                );
                tracing::warn!("{}", reason);
                LoopDecision::Deny { reason }
            }
            Err(e) => {
                let reason = format!("Policy evaluation error: {}", e);
                tracing::error!("{}", reason);
                LoopDecision::Deny { reason }
            }
        }
    }
}
/// Policy gate that restricts tool calls to an allow-list of tool names.
#[derive(Debug, Clone)]
pub struct ToolFilterPolicyGate {
    /// Names of the tools that may be called when `allow_all` is `false`.
    allowed_tools: std::collections::HashSet<String>,
    /// When `true`, the allow-list is bypassed and every action is allowed.
    allow_all: bool,
}

impl ToolFilterPolicyGate {
    /// Creates a gate that allows only the tools named in `tools`.
    ///
    /// Duplicate names collapse into one entry; an empty slice yields a gate that denies
    /// every tool call.
    pub fn allow(tools: &[&str]) -> Self {
        Self {
            // Dereference before converting so the `str` fast path is used instead of the
            // `Display`-based `to_string` on `&&str`.
            allowed_tools: tools.iter().map(|tool| (*tool).to_owned()).collect(),
            allow_all: false,
        }
    }

    /// Creates a gate that allows every action, with an empty allow-list.
    pub fn allow_all() -> Self {
        Self {
            allowed_tools: std::collections::HashSet::new(),
            allow_all: true,
        }
    }
}
#[async_trait]
impl ReasoningPolicyGate for ToolFilterPolicyGate {
    /// Denies tool calls whose name is not on the allow-list; allows everything else.
    ///
    /// When `allow_all` is set the allow-list is not consulted. Actions other than tool calls
    /// are always allowed. `_agent_id` and `_state` are not consulted.
    async fn evaluate_action(
        &self,
        _agent_id: &AgentId,
        action: &ProposedAction,
        _state: &LoopState,
    ) -> LoopDecision {
        if !self.allow_all {
            if let ProposedAction::ToolCall { name, .. } = action {
                let permitted = self.allowed_tools.contains(name.as_str());
                if !permitted {
                    return LoopDecision::Deny {
                        reason: format!("Tool '{}' not in allowed list", name),
                    };
                }
            }
        }

        LoopDecision::Allow
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reasoning::conversation::Conversation;
    use crate::reasoning::loop_types::LoopState;

    /// Builds a fresh agent id together with a loop state created for it from a new
    /// `Conversation`.
    fn fixture() -> (AgentId, LoopState) {
        let agent_id = AgentId::new();
        let state = LoopState::new(agent_id, Conversation::new());
        (agent_id, state)
    }

    /// Builds a tool-call action for `name` with call id `"c1"` and arguments `"{}"`.
    fn tool_call(name: &'static str) -> ProposedAction {
        ProposedAction::ToolCall {
            call_id: "c1".into(),
            name: name.into(),
            arguments: "{}".into(),
        }
    }

    /// The permissive default gate allows every kind of action.
    #[tokio::test]
    async fn test_default_gate_allows_all_actions() {
        let gate = DefaultPolicyGate::permissive();
        let (agent_id, state) = fixture();

        let actions = [
            tool_call("search"),
            ProposedAction::Delegate {
                target: "other_agent".into(),
                message: "hello".into(),
            },
            ProposedAction::Respond {
                content: "done".into(),
            },
            ProposedAction::Terminate {
                reason: "done".into(),
                output: "result".into(),
            },
        ];

        for action in &actions {
            let decision = gate.evaluate_action(&agent_id, action, &state).await;
            assert!(matches!(decision, LoopDecision::Allow));
        }
    }

    /// The standard-mode default gate allows a tool call.
    #[tokio::test]
    async fn test_default_gate_standard_mode() {
        let gate = DefaultPolicyGate::new();
        let (agent_id, state) = fixture();

        let action = tool_call("search");
        let decision = gate.evaluate_action(&agent_id, &action, &state).await;
        assert!(matches!(decision, LoopDecision::Allow));
    }

    /// A tool on the allow-list is allowed.
    #[tokio::test]
    async fn test_tool_filter_allows_whitelisted_tools() {
        let gate = ToolFilterPolicyGate::allow(&["search", "calculator"]);
        let (agent_id, state) = fixture();

        let action = tool_call("search");
        let decision = gate.evaluate_action(&agent_id, &action, &state).await;
        assert!(matches!(decision, LoopDecision::Allow));
    }

    /// A tool missing from the allow-list is denied, and the reason names the tool.
    #[tokio::test]
    async fn test_tool_filter_denies_non_whitelisted_tools() {
        let gate = ToolFilterPolicyGate::allow(&["search"]);
        let (agent_id, state) = fixture();

        let action = tool_call("delete_everything");
        let decision = gate.evaluate_action(&agent_id, &action, &state).await;
        assert!(matches!(decision, LoopDecision::Deny { .. }));
        if let LoopDecision::Deny { reason } = decision {
            assert!(reason.contains("delete_everything"));
            assert!(reason.contains("not in allowed list"));
        }
    }

    /// Actions that are not tool calls pass the filter regardless of the allow-list.
    #[tokio::test]
    async fn test_tool_filter_allows_non_tool_actions() {
        let gate = ToolFilterPolicyGate::allow(&["search"]);
        let (agent_id, state) = fixture();

        let actions = [
            ProposedAction::Respond {
                content: "hello".into(),
            },
            ProposedAction::Delegate {
                target: "other".into(),
                message: "hi".into(),
            },
            ProposedAction::Terminate {
                reason: "done".into(),
                output: "result".into(),
            },
        ];

        for action in &actions {
            let decision = gate.evaluate_action(&agent_id, action, &state).await;
            assert!(matches!(decision, LoopDecision::Allow));
        }
    }

    /// The allow-all filter lets any tool through.
    #[tokio::test]
    async fn test_tool_filter_allow_all() {
        let gate = ToolFilterPolicyGate::allow_all();
        let (agent_id, state) = fixture();

        let action = tool_call("anything");
        let decision = gate.evaluate_action(&agent_id, &action, &state).await;
        assert!(matches!(decision, LoopDecision::Allow));
    }
}