Skip to main content

llmtrace_security/
action_policy.rs

//! Action-selector policy enforcement and context minimization.
//!
//! Implements two patterns from "Design Patterns for Securing LLM Agents"
//! (IBM/EPFL/ETH/Google/Microsoft):
//!
//! 1. **Action-Selector Pattern** — enforce that agents can only invoke tools
//!    from a predefined allowlist. Any tool call not on the list is blocked.
//! 2. **Context-Minimization Pattern** — strip unnecessary context from
//!    requests to reduce attack surface.
//!
//! The [`PolicyEngine`] orchestrates multiple [`ActionPolicy`] instances and
//! a [`ContextMinimizer`], producing [`PolicyDecision`]s with attached
//! [`SecurityFinding`]s.
//!
//! # Example
//!
//! ```
//! use llmtrace_security::action_policy::{ActionPolicy, PolicyEngine, ContextMinimizer, Message};
//! use llmtrace_core::{AgentAction, AgentActionType};
//!
//! let mut engine = PolicyEngine::new();
//! engine.add_policy(ActionPolicy::restrictive("prod", "Production Policy"));
//!
//! let action = AgentAction::new(AgentActionType::ToolCall, "unknown_tool".to_string());
//! let decision = engine.evaluate_action(&action, None, "session-1");
//! assert!(decision.is_denied());
//!
//! let messages = vec![
//!     Message::new("system", "You are helpful."),
//!     Message::new("user", "Hello"),
//! ];
//! let minimized = engine.minimize_context(&messages);
//! assert!(!minimized.is_empty());
//! ```

36use crate::tool_registry::ToolDefinition;
37use llmtrace_core::{AgentAction, AgentActionType, SecurityFinding, SecuritySeverity};
38use regex::Regex;
39use std::collections::{HashMap, HashSet};
40use std::fmt;
41use std::sync::RwLock;
42
43// ---------------------------------------------------------------------------
44// EnforcementMode
45// ---------------------------------------------------------------------------
46
/// How an [`ActionPolicy`] reacts when an evaluated action violates it.
///
/// The mode decides whether a violation is merely recorded, always blocked,
/// or blocked depending on the severity of the findings it produced.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EnforcementMode {
    /// Record violations as findings but never block the action.
    Audit,
    /// Block every violation with a deny decision.
    Enforce,
    /// Block violations that carry a High (or higher) severity finding; warn on the rest.
    Adaptive,
}

impl fmt::Display for EnforcementMode {
    /// Writes the lowercase mode name: `audit`, `enforce`, or `adaptive`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Audit => "audit",
            Self::Enforce => "enforce",
            Self::Adaptive => "adaptive",
        };
        f.write_str(label)
    }
}
70
71// ---------------------------------------------------------------------------
72// PolicyVerdict
73// ---------------------------------------------------------------------------
74
/// Outcome of evaluating an action against a single policy.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyVerdict {
    /// The action may proceed.
    Allow,
    /// The action is blocked; the payload explains why.
    Deny(String),
    /// The action may proceed, but the payload describes a concern.
    Warn(String),
}

impl fmt::Display for PolicyVerdict {
    /// Renders `allow`, `deny: <reason>`, or `warn: <reason>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (label, reason) = match self {
            Self::Allow => return f.write_str("allow"),
            Self::Deny(reason) => ("deny", reason),
            Self::Warn(reason) => ("warn", reason),
        };
        write!(f, "{}: {}", label, reason)
    }
}
95
96// ---------------------------------------------------------------------------
97// PolicyDecision
98// ---------------------------------------------------------------------------
99
100/// Full decision from policy evaluation, including verdict and findings.
101#[derive(Debug, Clone)]
102pub struct PolicyDecision {
103    /// The overall verdict.
104    pub verdict: PolicyVerdict,
105    /// Security findings produced during evaluation.
106    pub findings: Vec<SecurityFinding>,
107    /// Which policy ID produced the decision (empty if from the engine).
108    pub policy_id: String,
109}
110
111impl PolicyDecision {
112    /// Create an allow decision with no findings.
113    pub fn allow() -> Self {
114        Self {
115            verdict: PolicyVerdict::Allow,
116            findings: Vec::new(),
117            policy_id: String::new(),
118        }
119    }
120
121    /// Create a deny decision with a reason and findings.
122    pub fn deny(reason: String, findings: Vec<SecurityFinding>, policy_id: String) -> Self {
123        Self {
124            verdict: PolicyVerdict::Deny(reason),
125            findings,
126            policy_id,
127        }
128    }
129
130    /// Create a warn decision with a reason and findings.
131    pub fn warn(reason: String, findings: Vec<SecurityFinding>, policy_id: String) -> Self {
132        Self {
133            verdict: PolicyVerdict::Warn(reason),
134            findings,
135            policy_id,
136        }
137    }
138
139    /// Returns `true` if the verdict is [`PolicyVerdict::Allow`].
140    pub fn is_allowed(&self) -> bool {
141        matches!(self.verdict, PolicyVerdict::Allow)
142    }
143
144    /// Returns `true` if the verdict is [`PolicyVerdict::Deny`].
145    pub fn is_denied(&self) -> bool {
146        matches!(self.verdict, PolicyVerdict::Deny(_))
147    }
148
149    /// Returns `true` if the verdict is [`PolicyVerdict::Warn`].
150    pub fn is_warned(&self) -> bool {
151        matches!(self.verdict, PolicyVerdict::Warn(_))
152    }
153}
154
155// ---------------------------------------------------------------------------
156// ActionPolicy
157// ---------------------------------------------------------------------------
158
159/// Policy for controlling which actions an agent can take.
160///
161/// Combines allowlist/blocklist enforcement, risk score thresholds, action
162/// type filtering, and session-level rate limiting into a single evaluable
163/// policy.
164///
165/// Use the builder methods to configure, or the convenience constructors
166/// [`ActionPolicy::permissive`] and [`ActionPolicy::restrictive`].
167#[derive(Debug, Clone)]
168pub struct ActionPolicy {
169    /// Policy identifier.
170    pub id: String,
171    /// Human-readable name.
172    pub name: String,
173    /// Enforcement mode.
174    pub mode: EnforcementMode,
175    /// Allowed tool IDs (if set, only these tools are permitted).
176    pub allowed_tools: Option<HashSet<String>>,
177    /// Blocked tool IDs (these tools are always denied).
178    pub blocked_tools: HashSet<String>,
179    /// Maximum risk score allowed (tools with higher risk are blocked).
180    pub max_risk_score: f64,
181    /// Allowed action types (if set, only these types are permitted).
182    pub allowed_action_types: Option<HashSet<AgentActionType>>,
183    /// Maximum total actions per session.
184    pub max_actions_per_session: Option<u32>,
185    /// Whether to allow actions on unregistered tools.
186    pub allow_unregistered: bool,
187}
188
189impl ActionPolicy {
190    /// Create a new action policy with sensible defaults.
191    ///
192    /// Defaults: enforce mode, no allowlist, no blocklist, max risk 1.0,
193    /// all action types allowed, no session limit, unregistered tools allowed.
194    pub fn new(id: &str, name: &str) -> Self {
195        Self {
196            id: id.to_string(),
197            name: name.to_string(),
198            mode: EnforcementMode::Enforce,
199            allowed_tools: None,
200            blocked_tools: HashSet::new(),
201            max_risk_score: 1.0,
202            allowed_action_types: None,
203            max_actions_per_session: None,
204            allow_unregistered: true,
205        }
206    }
207
208    /// Create a permissive policy that allows everything in audit mode.
209    ///
210    /// All actions are allowed; violations are only logged as findings.
211    pub fn permissive(id: &str, name: &str) -> Self {
212        Self {
213            id: id.to_string(),
214            name: name.to_string(),
215            mode: EnforcementMode::Audit,
216            allowed_tools: None,
217            blocked_tools: HashSet::new(),
218            max_risk_score: 1.0,
219            allowed_action_types: None,
220            max_actions_per_session: None,
221            allow_unregistered: true,
222        }
223    }
224
225    /// Create a restrictive policy that denies by default.
226    ///
227    /// Requires explicit allowlist, blocks unregistered tools, and enforces
228    /// a conservative risk threshold of 0.7.
229    pub fn restrictive(id: &str, name: &str) -> Self {
230        Self {
231            id: id.to_string(),
232            name: name.to_string(),
233            mode: EnforcementMode::Enforce,
234            allowed_tools: Some(HashSet::new()),
235            blocked_tools: HashSet::new(),
236            max_risk_score: 0.7,
237            allowed_action_types: None,
238            max_actions_per_session: None,
239            allow_unregistered: false,
240        }
241    }
242
243    /// Set the enforcement mode.
244    pub fn with_mode(mut self, mode: EnforcementMode) -> Self {
245        self.mode = mode;
246        self
247    }
248
249    /// Set the allowed tool IDs. Only these tools will be permitted.
250    pub fn with_allowed_tools(mut self, tools: HashSet<String>) -> Self {
251        self.allowed_tools = Some(tools);
252        self
253    }
254
255    /// Set the blocked tool IDs. These tools are always denied.
256    pub fn with_blocked_tools(mut self, tools: HashSet<String>) -> Self {
257        self.blocked_tools = tools;
258        self
259    }
260
261    /// Set the maximum risk score allowed.
262    pub fn with_max_risk_score(mut self, score: f64) -> Self {
263        self.max_risk_score = score.clamp(0.0, 1.0);
264        self
265    }
266
267    /// Set the allowed action types.
268    pub fn with_allowed_action_types(mut self, types: HashSet<AgentActionType>) -> Self {
269        self.allowed_action_types = Some(types);
270        self
271    }
272
273    /// Set the maximum actions per session.
274    pub fn with_max_actions_per_session(mut self, max: u32) -> Self {
275        self.max_actions_per_session = Some(max);
276        self
277    }
278
279    /// Set whether unregistered tools are allowed.
280    pub fn with_allow_unregistered(mut self, allow: bool) -> Self {
281        self.allow_unregistered = allow;
282        self
283    }
284
285    /// Evaluate an action against this policy.
286    ///
287    /// Returns a [`PolicyDecision`] indicating whether the action is allowed,
288    /// denied, or warned, along with any [`SecurityFinding`]s.
289    pub fn evaluate(
290        &self,
291        action: &AgentAction,
292        tool_def: Option<&ToolDefinition>,
293    ) -> PolicyDecision {
294        let mut findings = Vec::new();
295        let mut violations: Vec<String> = Vec::new();
296
297        // 1. Check action type allowlist
298        if let Some(ref allowed_types) = self.allowed_action_types {
299            if !allowed_types.contains(&action.action_type) {
300                let reason = format!("Action type '{}' not in allowed types", action.action_type);
301                violations.push(reason.clone());
302                findings.push(self.make_finding(
303                    SecuritySeverity::High,
304                    "action_type_blocked",
305                    &reason,
306                    &action.name,
307                ));
308            }
309        }
310
311        // 2. Check blocklist
312        if self.blocked_tools.contains(&action.name) {
313            let reason = format!("Tool '{}' is on the blocklist", action.name);
314            violations.push(reason.clone());
315            findings.push(self.make_finding(
316                SecuritySeverity::High,
317                "tool_blocked",
318                &reason,
319                &action.name,
320            ));
321        }
322
323        // 3. Check allowlist (only for tool calls and skill invocations)
324        if let Some(ref allowed) = self.allowed_tools {
325            let is_tool_like = action.action_type == AgentActionType::ToolCall
326                || action.action_type == AgentActionType::SkillInvocation;
327            if is_tool_like && !allowed.contains(&action.name) {
328                let reason = format!("Tool '{}' not in allowlist", action.name);
329                violations.push(reason.clone());
330                findings.push(self.make_finding(
331                    SecuritySeverity::High,
332                    "tool_not_allowed",
333                    &reason,
334                    &action.name,
335                ));
336            }
337        }
338
339        // 4. Check unregistered tool
340        if !self.allow_unregistered && tool_def.is_none() {
341            let is_tool_like = action.action_type == AgentActionType::ToolCall
342                || action.action_type == AgentActionType::SkillInvocation;
343            if is_tool_like {
344                let reason = format!("Unregistered tool '{}' not allowed", action.name);
345                violations.push(reason.clone());
346                findings.push(self.make_finding(
347                    SecuritySeverity::High,
348                    "unregistered_tool_blocked",
349                    &reason,
350                    &action.name,
351                ));
352            }
353        }
354
355        // 5. Check risk score
356        if let Some(tool) = tool_def {
357            if tool.risk_score > self.max_risk_score {
358                let reason = format!(
359                    "Tool '{}' risk score {:.2} exceeds max {:.2}",
360                    action.name, tool.risk_score, self.max_risk_score
361                );
362                violations.push(reason.clone());
363                findings.push(self.make_finding(
364                    SecuritySeverity::High,
365                    "risk_score_exceeded",
366                    &reason,
367                    &action.name,
368                ));
369            }
370        }
371
372        // Apply enforcement mode
373        if violations.is_empty() {
374            PolicyDecision {
375                verdict: PolicyVerdict::Allow,
376                findings,
377                policy_id: self.id.clone(),
378            }
379        } else {
380            let combined_reason = violations.join("; ");
381            match self.mode {
382                EnforcementMode::Audit => {
383                    // Log but allow
384                    PolicyDecision {
385                        verdict: PolicyVerdict::Warn(combined_reason),
386                        findings,
387                        policy_id: self.id.clone(),
388                    }
389                }
390                EnforcementMode::Enforce => PolicyDecision {
391                    verdict: PolicyVerdict::Deny(combined_reason),
392                    findings,
393                    policy_id: self.id.clone(),
394                },
395                EnforcementMode::Adaptive => {
396                    // Block if any finding is High or Critical, warn otherwise
397                    let has_high = findings
398                        .iter()
399                        .any(|f| f.severity >= SecuritySeverity::High);
400                    if has_high {
401                        PolicyDecision {
402                            verdict: PolicyVerdict::Deny(combined_reason),
403                            findings,
404                            policy_id: self.id.clone(),
405                        }
406                    } else {
407                        PolicyDecision {
408                            verdict: PolicyVerdict::Warn(combined_reason),
409                            findings,
410                            policy_id: self.id.clone(),
411                        }
412                    }
413                }
414            }
415        }
416    }
417
418    /// Create a security finding for a policy violation.
419    fn make_finding(
420        &self,
421        severity: SecuritySeverity,
422        finding_type: &str,
423        description: &str,
424        tool_name: &str,
425    ) -> SecurityFinding {
426        SecurityFinding::new(
427            severity,
428            format!("policy_{}", finding_type),
429            format!("[{}] {}", self.name, description),
430            0.95,
431        )
432        .with_location("action_policy".to_string())
433        .with_metadata("policy_id".to_string(), self.id.clone())
434        .with_metadata("policy_name".to_string(), self.name.clone())
435        .with_metadata("tool_name".to_string(), tool_name.to_string())
436        .with_metadata("enforcement_mode".to_string(), self.mode.to_string())
437    }
438}
439
440// ---------------------------------------------------------------------------
441// Message
442// ---------------------------------------------------------------------------
443
/// One chat message (a role plus its text) fed to the [`ContextMinimizer`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Message {
    /// Speaker role, e.g. `"system"`, `"user"`, `"assistant"`, or `"tool"`.
    pub role: String,
    /// Text body of the message.
    pub content: String,
}

impl Message {
    /// Build a message by copying the given role and content into owned strings.
    pub fn new(role: &str, content: &str) -> Self {
        let role = String::from(role);
        let content = String::from(content);
        Self { role, content }
    }
}
462
463// ---------------------------------------------------------------------------
464// ContextMinimizer
465// ---------------------------------------------------------------------------
466
467/// Strips unnecessary context from LLM request messages to reduce attack surface.
468///
469/// Implements the context-minimization pattern: only the minimum necessary
470/// context is forwarded to the LLM, reducing the window for prompt injection
471/// attacks embedded in prior conversation turns.
472pub struct ContextMinimizer {
473    /// Maximum number of conversation turns to keep.
474    pub max_turns: usize,
475    /// Whether to strip system prompts from forwarded tool contexts.
476    pub strip_system_prompts: bool,
477    /// Whether to strip prior tool results from context.
478    pub strip_prior_tool_results: bool,
479    /// Maximum total context characters.
480    pub max_context_chars: usize,
481    /// Patterns to always strip (compiled regex).
482    strip_patterns: Vec<Regex>,
483}
484
485impl ContextMinimizer {
486    /// Create a new context minimizer with custom settings.
487    pub fn new(
488        max_turns: usize,
489        strip_system_prompts: bool,
490        strip_prior_tool_results: bool,
491        max_context_chars: usize,
492    ) -> Self {
493        Self {
494            max_turns,
495            strip_system_prompts,
496            strip_prior_tool_results,
497            max_context_chars,
498            strip_patterns: Self::default_strip_patterns(),
499        }
500    }
501
502    /// Add a custom strip pattern.
503    pub fn with_strip_pattern(mut self, pattern: &str) -> Self {
504        if let Ok(re) = Regex::new(pattern) {
505            self.strip_patterns.push(re);
506        }
507        self
508    }
509
510    /// Build the default set of strip patterns.
511    ///
512    /// These remove common injection payloads and sensitive metadata that
513    /// should not be forwarded in context.
514    fn default_strip_patterns() -> Vec<Regex> {
515        let patterns = [
516            // API keys and tokens in context
517            r"(?i)(api[_\s]?key|secret[_\s]?key|auth[_\s]?token)\s*[:=]\s*\S+",
518            // Bearer tokens
519            r"(?i)Bearer\s+[A-Za-z0-9\-._~+/]+=*",
520            // Connection strings
521            r"(?i)(mongodb|postgres|mysql|redis)://\S+",
522        ];
523        patterns.iter().filter_map(|p| Regex::new(p).ok()).collect()
524    }
525
526    /// Minimize a sequence of messages according to the configured policy.
527    ///
528    /// Applies the following transformations in order:
529    /// 1. Strip system prompts (if configured)
530    /// 2. Strip tool results (if configured)
531    /// 3. Keep only the last `max_turns` user/assistant pairs
532    /// 4. Apply strip patterns to all remaining content
533    /// 5. Truncate to `max_context_chars` total
534    pub fn minimize_context(&self, messages: &[Message]) -> Vec<Message> {
535        let mut result: Vec<Message> = Vec::new();
536
537        // Phase 1: Filter by role
538        for msg in messages {
539            if self.strip_system_prompts && msg.role == "system" {
540                continue;
541            }
542            if self.strip_prior_tool_results && msg.role == "tool" {
543                continue;
544            }
545            result.push(msg.clone());
546        }
547
548        // Phase 2: Keep only the last max_turns conversation turns.
549        // A "turn" is a user message followed by an assistant response.
550        if result.len() > self.max_turns {
551            let skip = result.len() - self.max_turns;
552            result = result.into_iter().skip(skip).collect();
553        }
554
555        // Phase 3: Apply strip patterns to content
556        for msg in &mut result {
557            msg.content = self.minimize_text(&msg.content);
558        }
559
560        // Phase 4: Truncate to max_context_chars
561        let mut total_chars: usize = 0;
562        let mut truncated_result: Vec<Message> = Vec::new();
563        for msg in result {
564            let msg_chars = msg.content.chars().count();
565            if total_chars + msg_chars > self.max_context_chars {
566                let remaining = self.max_context_chars.saturating_sub(total_chars);
567                if remaining > 0 {
568                    let truncated_content: String = msg.content.chars().take(remaining).collect();
569                    truncated_result.push(Message {
570                        role: msg.role,
571                        content: truncated_content,
572                    });
573                }
574                break;
575            }
576            total_chars += msg_chars;
577            truncated_result.push(msg);
578        }
579
580        truncated_result
581    }
582
583    /// Strip patterns from a single text string.
584    pub fn minimize_text(&self, text: &str) -> String {
585        let mut result = text.to_string();
586        for pattern in &self.strip_patterns {
587            result = pattern.replace_all(&result, "[REDACTED]").to_string();
588        }
589        result
590    }
591}
592
593impl Default for ContextMinimizer {
594    /// Create a context minimizer with sensible defaults.
595    ///
596    /// - Keep last 10 turns
597    /// - Strip system prompts from tool contexts
598    /// - Do not strip tool results by default
599    /// - 50,000 character maximum
600    fn default() -> Self {
601        Self::new(10, true, false, 50_000)
602    }
603}
604
605impl fmt::Debug for ContextMinimizer {
606    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
607        f.debug_struct("ContextMinimizer")
608            .field("max_turns", &self.max_turns)
609            .field("strip_system_prompts", &self.strip_system_prompts)
610            .field("strip_prior_tool_results", &self.strip_prior_tool_results)
611            .field("max_context_chars", &self.max_context_chars)
612            .field("strip_pattern_count", &self.strip_patterns.len())
613            .finish()
614    }
615}
616
617// ---------------------------------------------------------------------------
618// PolicyEngine
619// ---------------------------------------------------------------------------
620
621/// Combines multiple [`ActionPolicy`] instances with a [`ContextMinimizer`]
622/// into a single evaluation engine.
623///
624/// Policies are evaluated in order. The first deny verdict wins; if none
625/// deny, the last warn wins; if none warn, the action is allowed. Session
626/// action counters track per-session usage for rate limiting.
627pub struct PolicyEngine {
628    /// Named policies (evaluated in order).
629    policies: Vec<ActionPolicy>,
630    /// Context minimizer.
631    context_minimizer: ContextMinimizer,
632    /// Session action counters: session_id -> count.
633    session_counters: RwLock<HashMap<String, u32>>,
634}
635
636impl PolicyEngine {
637    /// Create a new policy engine with default context minimization and no policies.
638    pub fn new() -> Self {
639        Self {
640            policies: Vec::new(),
641            context_minimizer: ContextMinimizer::default(),
642            session_counters: RwLock::new(HashMap::new()),
643        }
644    }
645
646    /// Create a new policy engine with a custom context minimizer.
647    pub fn with_context_minimizer(context_minimizer: ContextMinimizer) -> Self {
648        Self {
649            policies: Vec::new(),
650            context_minimizer,
651            session_counters: RwLock::new(HashMap::new()),
652        }
653    }
654
655    /// Add a policy to the engine. Policies are evaluated in insertion order.
656    pub fn add_policy(&mut self, policy: ActionPolicy) {
657        self.policies.push(policy);
658    }
659
660    /// Return the number of configured policies.
661    pub fn policy_count(&self) -> usize {
662        self.policies.len()
663    }
664
665    /// Evaluate an action against all configured policies.
666    ///
667    /// Returns a combined [`PolicyDecision`]:
668    /// - First deny verdict wins (short-circuit).
669    /// - If no deny, the last warn verdict is returned.
670    /// - If no violations at all, allow is returned.
671    ///
672    /// Also checks session-level rate limits if any policy has
673    /// `max_actions_per_session` configured.
674    pub fn evaluate_action(
675        &self,
676        action: &AgentAction,
677        tool_def: Option<&ToolDefinition>,
678        session_id: &str,
679    ) -> PolicyDecision {
680        // Check session limits first
681        let session_count = self.get_session_count(session_id);
682
683        let mut all_findings: Vec<SecurityFinding> = Vec::new();
684        let mut last_warn: Option<PolicyDecision> = None;
685
686        for policy in &self.policies {
687            // Check session limit for this policy
688            if let Some(max) = policy.max_actions_per_session {
689                if session_count >= max {
690                    let reason = format!(
691                        "Session '{}' exceeded max actions ({}/{})",
692                        session_id, session_count, max
693                    );
694                    let finding = SecurityFinding::new(
695                        SecuritySeverity::High,
696                        "policy_session_limit_exceeded".to_string(),
697                        format!("[{}] {}", policy.name, reason),
698                        0.95,
699                    )
700                    .with_location("action_policy".to_string())
701                    .with_metadata("policy_id".to_string(), policy.id.clone())
702                    .with_metadata("session_id".to_string(), session_id.to_string())
703                    .with_metadata("session_count".to_string(), session_count.to_string())
704                    .with_metadata("max_actions".to_string(), max.to_string());
705
706                    return match policy.mode {
707                        EnforcementMode::Audit => {
708                            PolicyDecision::warn(reason, vec![finding], policy.id.clone())
709                        }
710                        EnforcementMode::Enforce | EnforcementMode::Adaptive => {
711                            PolicyDecision::deny(reason, vec![finding], policy.id.clone())
712                        }
713                    };
714                }
715            }
716
717            let decision = policy.evaluate(action, tool_def);
718            all_findings.extend(decision.findings.clone());
719
720            match &decision.verdict {
721                PolicyVerdict::Deny(_) => {
722                    // First deny wins — return immediately
723                    return PolicyDecision {
724                        verdict: decision.verdict,
725                        findings: all_findings,
726                        policy_id: decision.policy_id,
727                    };
728                }
729                PolicyVerdict::Warn(_) => {
730                    last_warn = Some(decision);
731                }
732                PolicyVerdict::Allow => {}
733            }
734        }
735
736        // If there were warnings, return the last one with all findings
737        if let Some(warn) = last_warn {
738            return PolicyDecision {
739                verdict: warn.verdict,
740                findings: all_findings,
741                policy_id: warn.policy_id,
742            };
743        }
744
745        // All clear
746        PolicyDecision {
747            verdict: PolicyVerdict::Allow,
748            findings: all_findings,
749            policy_id: String::new(),
750        }
751    }
752
753    /// Minimize a sequence of messages using the configured context minimizer.
754    pub fn minimize_context(&self, messages: &[Message]) -> Vec<Message> {
755        self.context_minimizer.minimize_context(messages)
756    }
757
758    /// Record an action for a session (increment counter).
759    pub fn record_action(&self, session_id: &str) {
760        let mut counters = self
761            .session_counters
762            .write()
763            .expect("session counters lock poisoned");
764        let count = counters.entry(session_id.to_string()).or_insert(0);
765        *count += 1;
766    }
767
768    /// Reset the action counter for a session.
769    pub fn reset_session(&self, session_id: &str) {
770        let mut counters = self
771            .session_counters
772            .write()
773            .expect("session counters lock poisoned");
774        counters.remove(session_id);
775    }
776
777    /// Get the current action count for a session.
778    fn get_session_count(&self, session_id: &str) -> u32 {
779        let counters = self
780            .session_counters
781            .read()
782            .expect("session counters lock poisoned");
783        counters.get(session_id).copied().unwrap_or(0)
784    }
785}
786
787impl Default for PolicyEngine {
788    fn default() -> Self {
789        Self::new()
790    }
791}
792
793impl fmt::Debug for PolicyEngine {
794    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
795        f.debug_struct("PolicyEngine")
796            .field("policy_count", &self.policies.len())
797            .field("context_minimizer", &self.context_minimizer)
798            .finish()
799    }
800}
801
802// ===========================================================================
803// Tests
804// ===========================================================================
805
806#[cfg(test)]
807mod tests {
808    use super::*;
809    use crate::tool_registry::{ToolCategory, ToolDefinition};
810    use llmtrace_core::{AgentAction, AgentActionType};
811
812    // ---------------------------------------------------------------
813    // EnforcementMode
814    // ---------------------------------------------------------------
815
816    #[test]
817    fn test_enforcement_mode_display() {
818        assert_eq!(EnforcementMode::Audit.to_string(), "audit");
819        assert_eq!(EnforcementMode::Enforce.to_string(), "enforce");
820        assert_eq!(EnforcementMode::Adaptive.to_string(), "adaptive");
821    }
822
823    #[test]
824    fn test_enforcement_mode_equality() {
825        assert_eq!(EnforcementMode::Audit, EnforcementMode::Audit);
826        assert_ne!(EnforcementMode::Audit, EnforcementMode::Enforce);
827    }
828
829    // ---------------------------------------------------------------
830    // PolicyVerdict
831    // ---------------------------------------------------------------
832
833    #[test]
834    fn test_policy_verdict_display() {
835        assert_eq!(PolicyVerdict::Allow.to_string(), "allow");
836        assert_eq!(
837            PolicyVerdict::Deny("blocked".to_string()).to_string(),
838            "deny: blocked"
839        );
840        assert_eq!(
841            PolicyVerdict::Warn("caution".to_string()).to_string(),
842            "warn: caution"
843        );
844    }
845
846    // ---------------------------------------------------------------
847    // PolicyDecision
848    // ---------------------------------------------------------------
849
850    #[test]
851    fn test_policy_decision_allow() {
852        let decision = PolicyDecision::allow();
853        assert!(decision.is_allowed());
854        assert!(!decision.is_denied());
855        assert!(!decision.is_warned());
856        assert!(decision.findings.is_empty());
857    }
858
859    #[test]
860    fn test_policy_decision_deny() {
861        let decision = PolicyDecision::deny(
862            "blocked".to_string(),
863            vec![SecurityFinding::new(
864                SecuritySeverity::High,
865                "test".to_string(),
866                "test".to_string(),
867                0.9,
868            )],
869            "policy-1".to_string(),
870        );
871        assert!(decision.is_denied());
872        assert!(!decision.is_allowed());
873        assert!(!decision.is_warned());
874        assert_eq!(decision.findings.len(), 1);
875        assert_eq!(decision.policy_id, "policy-1");
876    }
877
878    #[test]
879    fn test_policy_decision_warn() {
880        let decision = PolicyDecision::warn("caution".to_string(), Vec::new(), "p1".to_string());
881        assert!(decision.is_warned());
882        assert!(!decision.is_allowed());
883        assert!(!decision.is_denied());
884    }
885
886    // ---------------------------------------------------------------
887    // ActionPolicy — constructors
888    // ---------------------------------------------------------------
889
890    #[test]
891    fn test_action_policy_new() {
892        let policy = ActionPolicy::new("test", "Test Policy");
893        assert_eq!(policy.id, "test");
894        assert_eq!(policy.name, "Test Policy");
895        assert_eq!(policy.mode, EnforcementMode::Enforce);
896        assert!(policy.allowed_tools.is_none());
897        assert!(policy.blocked_tools.is_empty());
898        assert!((policy.max_risk_score - 1.0).abs() < f64::EPSILON);
899        assert!(policy.allowed_action_types.is_none());
900        assert!(policy.max_actions_per_session.is_none());
901        assert!(policy.allow_unregistered);
902    }
903
904    #[test]
905    fn test_action_policy_permissive() {
906        let policy = ActionPolicy::permissive("perm", "Permissive");
907        assert_eq!(policy.mode, EnforcementMode::Audit);
908        assert!(policy.allowed_tools.is_none());
909        assert!(policy.allow_unregistered);
910        assert!((policy.max_risk_score - 1.0).abs() < f64::EPSILON);
911    }
912
913    #[test]
914    fn test_action_policy_restrictive() {
915        let policy = ActionPolicy::restrictive("strict", "Strict");
916        assert_eq!(policy.mode, EnforcementMode::Enforce);
917        assert!(policy.allowed_tools.is_some());
918        assert!(policy.allowed_tools.as_ref().unwrap().is_empty());
919        assert!(!policy.allow_unregistered);
920        assert!((policy.max_risk_score - 0.7).abs() < f64::EPSILON);
921    }
922
923    // ---------------------------------------------------------------
924    // ActionPolicy — builder
925    // ---------------------------------------------------------------
926
927    #[test]
928    fn test_action_policy_builder() {
929        let mut allowed = HashSet::new();
930        allowed.insert("web_search".to_string());
931        allowed.insert("file_read".to_string());
932
933        let mut blocked = HashSet::new();
934        blocked.insert("shell_exec".to_string());
935
936        let mut action_types = HashSet::new();
937        action_types.insert(AgentActionType::ToolCall);
938
939        let policy = ActionPolicy::new("custom", "Custom Policy")
940            .with_mode(EnforcementMode::Adaptive)
941            .with_allowed_tools(allowed.clone())
942            .with_blocked_tools(blocked.clone())
943            .with_max_risk_score(0.5)
944            .with_allowed_action_types(action_types.clone())
945            .with_max_actions_per_session(100)
946            .with_allow_unregistered(false);
947
948        assert_eq!(policy.mode, EnforcementMode::Adaptive);
949        assert_eq!(policy.allowed_tools.as_ref().unwrap().len(), 2);
950        assert!(policy.blocked_tools.contains("shell_exec"));
951        assert!((policy.max_risk_score - 0.5).abs() < f64::EPSILON);
952        assert!(policy
953            .allowed_action_types
954            .as_ref()
955            .unwrap()
956            .contains(&AgentActionType::ToolCall));
957        assert_eq!(policy.max_actions_per_session, Some(100));
958        assert!(!policy.allow_unregistered);
959    }
960
961    #[test]
962    fn test_action_policy_max_risk_score_clamped() {
963        let policy = ActionPolicy::new("t", "T").with_max_risk_score(1.5);
964        assert!((policy.max_risk_score - 1.0).abs() < f64::EPSILON);
965
966        let policy = ActionPolicy::new("t", "T").with_max_risk_score(-0.5);
967        assert!(policy.max_risk_score.abs() < f64::EPSILON);
968    }
969
970    // ---------------------------------------------------------------
971    // ActionPolicy — evaluate: allowlist
972    // ---------------------------------------------------------------
973
974    #[test]
975    fn test_evaluate_allowlist_permits_listed_tool() {
976        let mut allowed = HashSet::new();
977        allowed.insert("web_search".to_string());
978
979        let policy = ActionPolicy::new("p", "P").with_allowed_tools(allowed);
980        let action = AgentAction::new(AgentActionType::ToolCall, "web_search".to_string());
981        let decision = policy.evaluate(&action, None);
982        assert!(decision.is_allowed());
983    }
984
985    #[test]
986    fn test_evaluate_allowlist_blocks_unlisted_tool() {
987        let mut allowed = HashSet::new();
988        allowed.insert("web_search".to_string());
989
990        let policy = ActionPolicy::new("p", "P").with_allowed_tools(allowed);
991        let action = AgentAction::new(AgentActionType::ToolCall, "shell_exec".to_string());
992        let decision = policy.evaluate(&action, None);
993        assert!(decision.is_denied());
994        assert!(decision
995            .findings
996            .iter()
997            .any(|f| f.finding_type == "policy_tool_not_allowed"));
998    }
999
1000    #[test]
1001    fn test_evaluate_allowlist_skips_non_tool_actions() {
1002        let mut allowed = HashSet::new();
1003        allowed.insert("web_search".to_string());
1004
1005        let policy = ActionPolicy::new("p", "P").with_allowed_tools(allowed);
1006        let action = AgentAction::new(AgentActionType::CommandExecution, "ls -la".to_string());
1007        let decision = policy.evaluate(&action, None);
1008        // CommandExecution is not a tool-like action, allowlist should not apply
1009        assert!(decision.is_allowed());
1010    }
1011
1012    // ---------------------------------------------------------------
1013    // ActionPolicy — evaluate: blocklist
1014    // ---------------------------------------------------------------
1015
1016    #[test]
1017    fn test_evaluate_blocklist_denies_blocked_tool() {
1018        let mut blocked = HashSet::new();
1019        blocked.insert("dangerous_tool".to_string());
1020
1021        let policy = ActionPolicy::new("p", "P").with_blocked_tools(blocked);
1022        let action = AgentAction::new(AgentActionType::ToolCall, "dangerous_tool".to_string());
1023        let decision = policy.evaluate(&action, None);
1024        assert!(decision.is_denied());
1025        assert!(decision
1026            .findings
1027            .iter()
1028            .any(|f| f.finding_type == "policy_tool_blocked"));
1029    }
1030
1031    #[test]
1032    fn test_evaluate_blocklist_allows_non_blocked_tool() {
1033        let mut blocked = HashSet::new();
1034        blocked.insert("dangerous_tool".to_string());
1035
1036        let policy = ActionPolicy::new("p", "P").with_blocked_tools(blocked);
1037        let action = AgentAction::new(AgentActionType::ToolCall, "safe_tool".to_string());
1038        let decision = policy.evaluate(&action, None);
1039        assert!(decision.is_allowed());
1040    }
1041
1042    // ---------------------------------------------------------------
1043    // ActionPolicy — evaluate: unregistered tools
1044    // ---------------------------------------------------------------
1045
1046    #[test]
1047    fn test_evaluate_unregistered_blocked_when_not_allowed() {
1048        let policy = ActionPolicy::new("p", "P").with_allow_unregistered(false);
1049        let action = AgentAction::new(AgentActionType::ToolCall, "unknown".to_string());
1050        let decision = policy.evaluate(&action, None);
1051        assert!(decision.is_denied());
1052        assert!(decision
1053            .findings
1054            .iter()
1055            .any(|f| f.finding_type == "policy_unregistered_tool_blocked"));
1056    }
1057
1058    #[test]
1059    fn test_evaluate_unregistered_allowed_when_permitted() {
1060        let policy = ActionPolicy::new("p", "P").with_allow_unregistered(true);
1061        let action = AgentAction::new(AgentActionType::ToolCall, "unknown".to_string());
1062        let decision = policy.evaluate(&action, None);
1063        assert!(decision.is_allowed());
1064    }
1065
1066    #[test]
1067    fn test_evaluate_unregistered_check_skips_non_tool_actions() {
1068        let policy = ActionPolicy::new("p", "P").with_allow_unregistered(false);
1069        let action = AgentAction::new(AgentActionType::FileAccess, "/etc/passwd".to_string());
1070        let decision = policy.evaluate(&action, None);
1071        // FileAccess is not tool-like, unregistered check should not apply
1072        assert!(decision.is_allowed());
1073    }
1074
1075    // ---------------------------------------------------------------
1076    // ActionPolicy — evaluate: risk score
1077    // ---------------------------------------------------------------
1078
1079    #[test]
1080    fn test_evaluate_risk_score_blocks_high_risk() {
1081        let policy = ActionPolicy::new("p", "P").with_max_risk_score(0.5);
1082        let tool =
1083            ToolDefinition::new("risky", "Risky", ToolCategory::CodeExecution).with_risk_score(0.9);
1084        let action = AgentAction::new(AgentActionType::ToolCall, "risky".to_string());
1085        let decision = policy.evaluate(&action, Some(&tool));
1086        assert!(decision.is_denied());
1087        assert!(decision
1088            .findings
1089            .iter()
1090            .any(|f| f.finding_type == "policy_risk_score_exceeded"));
1091    }
1092
1093    #[test]
1094    fn test_evaluate_risk_score_allows_within_threshold() {
1095        let policy = ActionPolicy::new("p", "P").with_max_risk_score(0.5);
1096        let tool =
1097            ToolDefinition::new("safe", "Safe", ToolCategory::DataRetrieval).with_risk_score(0.3);
1098        let action = AgentAction::new(AgentActionType::ToolCall, "safe".to_string());
1099        let decision = policy.evaluate(&action, Some(&tool));
1100        assert!(decision.is_allowed());
1101    }
1102
1103    #[test]
1104    fn test_evaluate_risk_score_at_boundary() {
1105        let policy = ActionPolicy::new("p", "P").with_max_risk_score(0.5);
1106        let tool =
1107            ToolDefinition::new("border", "Border", ToolCategory::WebAccess).with_risk_score(0.5);
1108        let action = AgentAction::new(AgentActionType::ToolCall, "border".to_string());
1109        let decision = policy.evaluate(&action, Some(&tool));
1110        // Exactly at boundary should be allowed (not >)
1111        assert!(decision.is_allowed());
1112    }
1113
1114    // ---------------------------------------------------------------
1115    // ActionPolicy — evaluate: action types
1116    // ---------------------------------------------------------------
1117
1118    #[test]
1119    fn test_evaluate_action_type_allowed() {
1120        let mut types = HashSet::new();
1121        types.insert(AgentActionType::ToolCall);
1122        types.insert(AgentActionType::WebAccess);
1123
1124        let policy = ActionPolicy::new("p", "P").with_allowed_action_types(types);
1125        let action = AgentAction::new(AgentActionType::ToolCall, "search".to_string());
1126        let decision = policy.evaluate(&action, None);
1127        assert!(decision.is_allowed());
1128    }
1129
1130    #[test]
1131    fn test_evaluate_action_type_blocked() {
1132        let mut types = HashSet::new();
1133        types.insert(AgentActionType::ToolCall);
1134
1135        let policy = ActionPolicy::new("p", "P").with_allowed_action_types(types);
1136        let action = AgentAction::new(AgentActionType::CommandExecution, "rm -rf /".to_string());
1137        let decision = policy.evaluate(&action, None);
1138        assert!(decision.is_denied());
1139        assert!(decision
1140            .findings
1141            .iter()
1142            .any(|f| f.finding_type == "policy_action_type_blocked"));
1143    }
1144
1145    // ---------------------------------------------------------------
1146    // ActionPolicy — evaluate: enforcement modes
1147    // ---------------------------------------------------------------
1148
1149    #[test]
1150    fn test_evaluate_audit_mode_warns_instead_of_deny() {
1151        let mut blocked = HashSet::new();
1152        blocked.insert("bad_tool".to_string());
1153
1154        let policy = ActionPolicy::new("p", "P")
1155            .with_mode(EnforcementMode::Audit)
1156            .with_blocked_tools(blocked);
1157
1158        let action = AgentAction::new(AgentActionType::ToolCall, "bad_tool".to_string());
1159        let decision = policy.evaluate(&action, None);
1160        assert!(decision.is_warned());
1161        assert!(!decision.findings.is_empty());
1162    }
1163
1164    #[test]
1165    fn test_evaluate_adaptive_mode_denies_high_risk() {
1166        let mut blocked = HashSet::new();
1167        blocked.insert("bad_tool".to_string());
1168
1169        let policy = ActionPolicy::new("p", "P")
1170            .with_mode(EnforcementMode::Adaptive)
1171            .with_blocked_tools(blocked);
1172
1173        let action = AgentAction::new(AgentActionType::ToolCall, "bad_tool".to_string());
1174        let decision = policy.evaluate(&action, None);
1175        // The blocklist finding is High severity, so adaptive should deny
1176        assert!(decision.is_denied());
1177    }
1178
1179    #[test]
1180    fn test_evaluate_multiple_violations() {
1181        let mut blocked = HashSet::new();
1182        blocked.insert("shell_exec".to_string());
1183
1184        let policy = ActionPolicy::new("p", "P")
1185            .with_blocked_tools(blocked)
1186            .with_max_risk_score(0.5);
1187
1188        let tool = ToolDefinition::new("shell_exec", "Shell", ToolCategory::CodeExecution)
1189            .with_risk_score(0.9);
1190        let action = AgentAction::new(AgentActionType::ToolCall, "shell_exec".to_string());
1191        let decision = policy.evaluate(&action, Some(&tool));
1192        assert!(decision.is_denied());
1193        // Should have findings for both blocklist and risk score
1194        assert!(decision.findings.len() >= 2);
1195    }
1196
1197    // ---------------------------------------------------------------
1198    // Message
1199    // ---------------------------------------------------------------
1200
1201    #[test]
1202    fn test_message_new() {
1203        let msg = Message::new("user", "Hello!");
1204        assert_eq!(msg.role, "user");
1205        assert_eq!(msg.content, "Hello!");
1206    }
1207
1208    #[test]
1209    fn test_message_equality() {
1210        let a = Message::new("user", "hi");
1211        let b = Message::new("user", "hi");
1212        assert_eq!(a, b);
1213
1214        let c = Message::new("assistant", "hi");
1215        assert_ne!(a, c);
1216    }
1217
1218    // ---------------------------------------------------------------
1219    // ContextMinimizer — defaults
1220    // ---------------------------------------------------------------
1221
1222    #[test]
1223    fn test_context_minimizer_default() {
1224        let minimizer = ContextMinimizer::default();
1225        assert_eq!(minimizer.max_turns, 10);
1226        assert!(minimizer.strip_system_prompts);
1227        assert!(!minimizer.strip_prior_tool_results);
1228        assert_eq!(minimizer.max_context_chars, 50_000);
1229    }
1230
1231    // ---------------------------------------------------------------
1232    // ContextMinimizer — minimize_context
1233    // ---------------------------------------------------------------
1234
1235    #[test]
1236    fn test_minimize_strips_system_prompts() {
1237        let minimizer = ContextMinimizer::new(10, true, false, 50_000);
1238        let messages = vec![
1239            Message::new("system", "You are helpful."),
1240            Message::new("user", "Hello"),
1241            Message::new("assistant", "Hi there!"),
1242        ];
1243        let result = minimizer.minimize_context(&messages);
1244        assert!(!result.iter().any(|m| m.role == "system"));
1245        assert_eq!(result.len(), 2);
1246    }
1247
1248    #[test]
1249    fn test_minimize_keeps_system_prompts_when_disabled() {
1250        let minimizer = ContextMinimizer::new(10, false, false, 50_000);
1251        let messages = vec![
1252            Message::new("system", "You are helpful."),
1253            Message::new("user", "Hello"),
1254        ];
1255        let result = minimizer.minimize_context(&messages);
1256        assert!(result.iter().any(|m| m.role == "system"));
1257    }
1258
1259    #[test]
1260    fn test_minimize_strips_tool_results() {
1261        let minimizer = ContextMinimizer::new(10, false, true, 50_000);
1262        let messages = vec![
1263            Message::new("user", "Search for cats"),
1264            Message::new("tool", "{\"results\": [\"cat1\", \"cat2\"]}"),
1265            Message::new("assistant", "Here are the results."),
1266        ];
1267        let result = minimizer.minimize_context(&messages);
1268        assert!(!result.iter().any(|m| m.role == "tool"));
1269        assert_eq!(result.len(), 2);
1270    }
1271
1272    #[test]
1273    fn test_minimize_limits_turns() {
1274        let minimizer = ContextMinimizer::new(3, false, false, 50_000);
1275        let messages = vec![
1276            Message::new("user", "msg1"),
1277            Message::new("assistant", "resp1"),
1278            Message::new("user", "msg2"),
1279            Message::new("assistant", "resp2"),
1280            Message::new("user", "msg3"),
1281        ];
1282        let result = minimizer.minimize_context(&messages);
1283        assert_eq!(result.len(), 3);
1284        // Should keep the last 3 messages: msg2, resp2, msg3
1285        assert_eq!(result[0].content, "msg2");
1286        assert_eq!(result[1].content, "resp2");
1287        assert_eq!(result[2].content, "msg3");
1288        // Verify the first two were dropped
1289        assert!(!result.iter().any(|m| m.content == "msg1"));
1290        assert!(!result.iter().any(|m| m.content == "resp1"));
1291    }
1292
1293    #[test]
1294    fn test_minimize_truncates_to_max_chars() {
1295        let minimizer = ContextMinimizer::new(10, false, false, 20);
1296        let messages = vec![
1297            Message::new("user", "Hello World!"), // 12 chars
1298            Message::new("assistant", "This is a long response."), // 24 chars
1299        ];
1300        let result = minimizer.minimize_context(&messages);
1301        // First message is 12 chars, second is 24 chars, total would be 36 > 20
1302        assert!(result.len() <= 2);
1303        let total: usize = result.iter().map(|m| m.content.chars().count()).sum();
1304        assert!(total <= 20);
1305    }
1306
1307    #[test]
1308    fn test_minimize_text_strips_api_keys() {
1309        let minimizer = ContextMinimizer::default();
1310        let text = "Use api_key=sk-abc123xyz789 for access";
1311        let result = minimizer.minimize_text(text);
1312        assert!(result.contains("[REDACTED]"));
1313        assert!(!result.contains("sk-abc123xyz789"));
1314    }
1315
1316    #[test]
1317    fn test_minimize_text_strips_bearer_tokens() {
1318        let minimizer = ContextMinimizer::default();
1319        let text = "Authorization: Bearer eyJhbGciOiJIUzI1NiJ9.eyJ0ZXN0IjoxfQ.sig";
1320        let result = minimizer.minimize_text(text);
1321        assert!(result.contains("[REDACTED]"));
1322    }
1323
1324    #[test]
1325    fn test_minimize_text_strips_connection_strings() {
1326        let minimizer = ContextMinimizer::default();
1327        let text = "connect to postgres://user:pass@host:5432/db";
1328        let result = minimizer.minimize_text(text);
1329        assert!(result.contains("[REDACTED]"));
1330        assert!(!result.contains("postgres://"));
1331    }
1332
1333    #[test]
1334    fn test_minimize_with_custom_pattern() {
1335        let minimizer = ContextMinimizer::default().with_strip_pattern(r"(?i)SECRET_VALUE_\w+");
1336        let text = "The value is SECRET_VALUE_ABC123 here";
1337        let result = minimizer.minimize_text(text);
1338        assert!(result.contains("[REDACTED]"));
1339        assert!(!result.contains("SECRET_VALUE_ABC123"));
1340    }
1341
1342    #[test]
1343    fn test_minimize_empty_messages() {
1344        let minimizer = ContextMinimizer::default();
1345        let result = minimizer.minimize_context(&[]);
1346        assert!(result.is_empty());
1347    }
1348
1349    // ---------------------------------------------------------------
1350    // PolicyEngine — basic
1351    // ---------------------------------------------------------------
1352
1353    #[test]
1354    fn test_engine_new_no_policies() {
1355        let engine = PolicyEngine::new();
1356        assert_eq!(engine.policy_count(), 0);
1357    }
1358
1359    #[test]
1360    fn test_engine_add_policy() {
1361        let mut engine = PolicyEngine::new();
1362        engine.add_policy(ActionPolicy::new("p1", "Policy 1"));
1363        engine.add_policy(ActionPolicy::new("p2", "Policy 2"));
1364        assert_eq!(engine.policy_count(), 2);
1365    }
1366
1367    #[test]
1368    fn test_engine_allows_when_no_policies() {
1369        let engine = PolicyEngine::new();
1370        let action = AgentAction::new(AgentActionType::ToolCall, "any_tool".to_string());
1371        let decision = engine.evaluate_action(&action, None, "session-1");
1372        assert!(decision.is_allowed());
1373    }
1374
1375    // ---------------------------------------------------------------
1376    // PolicyEngine — evaluate
1377    // ---------------------------------------------------------------
1378
1379    #[test]
1380    fn test_engine_first_deny_wins() {
1381        let mut blocked1 = HashSet::new();
1382        blocked1.insert("tool_a".to_string());
1383
1384        let mut blocked2 = HashSet::new();
1385        blocked2.insert("tool_b".to_string());
1386
1387        let mut engine = PolicyEngine::new();
1388        engine.add_policy(ActionPolicy::new("p1", "P1").with_blocked_tools(blocked1));
1389        engine.add_policy(ActionPolicy::new("p2", "P2").with_blocked_tools(blocked2));
1390
1391        let action = AgentAction::new(AgentActionType::ToolCall, "tool_a".to_string());
1392        let decision = engine.evaluate_action(&action, None, "s1");
1393        assert!(decision.is_denied());
1394        assert_eq!(decision.policy_id, "p1");
1395    }
1396
1397    #[test]
1398    fn test_engine_warn_returned_when_no_deny() {
1399        let mut blocked = HashSet::new();
1400        blocked.insert("tool_a".to_string());
1401
1402        let mut engine = PolicyEngine::new();
1403        engine.add_policy(
1404            ActionPolicy::new("audit_p", "Audit Policy")
1405                .with_mode(EnforcementMode::Audit)
1406                .with_blocked_tools(blocked),
1407        );
1408
1409        let action = AgentAction::new(AgentActionType::ToolCall, "tool_a".to_string());
1410        let decision = engine.evaluate_action(&action, None, "s1");
1411        assert!(decision.is_warned());
1412    }
1413
1414    #[test]
1415    fn test_engine_allows_when_all_pass() {
1416        let mut allowed = HashSet::new();
1417        allowed.insert("web_search".to_string());
1418
1419        let mut engine = PolicyEngine::new();
1420        engine.add_policy(ActionPolicy::new("p1", "P1").with_allowed_tools(allowed));
1421
1422        let action = AgentAction::new(AgentActionType::ToolCall, "web_search".to_string());
1423        let decision = engine.evaluate_action(&action, None, "s1");
1424        assert!(decision.is_allowed());
1425    }
1426
1427    // ---------------------------------------------------------------
1428    // PolicyEngine — session counters
1429    // ---------------------------------------------------------------
1430
1431    #[test]
1432    fn test_engine_session_counter() {
1433        let mut engine = PolicyEngine::new();
1434        engine.add_policy(ActionPolicy::new("p1", "P1").with_max_actions_per_session(3));
1435
1436        engine.record_action("session-1");
1437        engine.record_action("session-1");
1438        engine.record_action("session-1");
1439
1440        let action = AgentAction::new(AgentActionType::ToolCall, "tool".to_string());
1441        let decision = engine.evaluate_action(&action, None, "session-1");
1442        assert!(decision.is_denied());
1443        assert!(decision
1444            .findings
1445            .iter()
1446            .any(|f| f.finding_type == "policy_session_limit_exceeded"));
1447    }
1448
1449    #[test]
1450    fn test_engine_session_counter_independent_sessions() {
1451        let mut engine = PolicyEngine::new();
1452        engine.add_policy(ActionPolicy::new("p1", "P1").with_max_actions_per_session(2));
1453
1454        engine.record_action("session-1");
1455        engine.record_action("session-1");
1456
1457        let action = AgentAction::new(AgentActionType::ToolCall, "tool".to_string());
1458
1459        // session-1 should be at limit
1460        let decision = engine.evaluate_action(&action, None, "session-1");
1461        assert!(decision.is_denied());
1462
1463        // session-2 should be fine
1464        let decision = engine.evaluate_action(&action, None, "session-2");
1465        assert!(decision.is_allowed());
1466    }
1467
1468    #[test]
1469    fn test_engine_reset_session() {
1470        let mut engine = PolicyEngine::new();
1471        engine.add_policy(ActionPolicy::new("p1", "P1").with_max_actions_per_session(2));
1472
1473        engine.record_action("session-1");
1474        engine.record_action("session-1");
1475
1476        let action = AgentAction::new(AgentActionType::ToolCall, "tool".to_string());
1477        let decision = engine.evaluate_action(&action, None, "session-1");
1478        assert!(decision.is_denied());
1479
1480        engine.reset_session("session-1");
1481        let decision = engine.evaluate_action(&action, None, "session-1");
1482        assert!(decision.is_allowed());
1483    }
1484
1485    #[test]
1486    fn test_engine_session_limit_audit_mode() {
1487        let mut engine = PolicyEngine::new();
1488        engine.add_policy(
1489            ActionPolicy::new("p1", "P1")
1490                .with_mode(EnforcementMode::Audit)
1491                .with_max_actions_per_session(1),
1492        );
1493
1494        engine.record_action("s1");
1495
1496        let action = AgentAction::new(AgentActionType::ToolCall, "tool".to_string());
1497        let decision = engine.evaluate_action(&action, None, "s1");
1498        // Audit mode should warn, not deny
1499        assert!(decision.is_warned());
1500    }
1501
1502    // ---------------------------------------------------------------
1503    // PolicyEngine — minimize_context
1504    // ---------------------------------------------------------------
1505
1506    #[test]
1507    fn test_engine_minimize_context() {
1508        let engine = PolicyEngine::new();
1509        let messages = vec![
1510            Message::new("system", "You are helpful."),
1511            Message::new("user", "Hello"),
1512            Message::new("assistant", "Hi!"),
1513        ];
1514        let result = engine.minimize_context(&messages);
1515        // Default minimizer strips system prompts
1516        assert!(!result.iter().any(|m| m.role == "system"));
1517    }
1518
1519    #[test]
1520    fn test_engine_with_custom_minimizer() {
1521        let minimizer = ContextMinimizer::new(2, false, false, 50_000);
1522        let engine = PolicyEngine::with_context_minimizer(minimizer);
1523
1524        let messages = vec![
1525            Message::new("user", "msg1"),
1526            Message::new("assistant", "resp1"),
1527            Message::new("user", "msg2"),
1528            Message::new("assistant", "resp2"),
1529        ];
1530        let result = engine.minimize_context(&messages);
1531        assert_eq!(result.len(), 2);
1532    }
1533
1534    // ---------------------------------------------------------------
1535    // PolicyEngine — debug
1536    // ---------------------------------------------------------------
1537
1538    #[test]
1539    fn test_engine_debug() {
1540        let engine = PolicyEngine::new();
1541        let debug = format!("{:?}", engine);
1542        assert!(debug.contains("PolicyEngine"));
1543        assert!(debug.contains("policy_count"));
1544    }
1545
1546    // ---------------------------------------------------------------
1547    // PolicyEngine — default
1548    // ---------------------------------------------------------------
1549
1550    #[test]
1551    fn test_engine_default() {
1552        let engine = PolicyEngine::default();
1553        assert_eq!(engine.policy_count(), 0);
1554    }
1555
1556    // ---------------------------------------------------------------
1557    // Integration: restrictive policy + tool definition
1558    // ---------------------------------------------------------------
1559
1560    #[test]
1561    fn test_integration_restrictive_with_registered_tool() {
1562        let mut allowed = HashSet::new();
1563        allowed.insert("web_search".to_string());
1564
1565        let policy = ActionPolicy::restrictive("strict", "Strict").with_allowed_tools(allowed);
1566
1567        let tool = ToolDefinition::new("web_search", "Web Search", ToolCategory::WebAccess)
1568            .with_risk_score(0.3);
1569
1570        let action = AgentAction::new(AgentActionType::ToolCall, "web_search".to_string());
1571        let decision = policy.evaluate(&action, Some(&tool));
1572        assert!(decision.is_allowed());
1573    }
1574
1575    #[test]
1576    fn test_integration_restrictive_blocks_unregistered() {
1577        let policy = ActionPolicy::restrictive("strict", "Strict");
1578        let action = AgentAction::new(AgentActionType::ToolCall, "unknown_tool".to_string());
1579        let decision = policy.evaluate(&action, None);
1580        assert!(decision.is_denied());
1581    }
1582
/// Being on the allowlist is not sufficient: a restrictive policy still denies a tool
/// whose risk score is too high. The denial must carry a
/// `policy_risk_score_exceeded` finding.
#[test]
fn test_integration_restrictive_blocks_high_risk() {
    let allowlist: HashSet<String> = ["shell_exec"].iter().copied().map(String::from).collect();
    let policy = ActionPolicy::restrictive("strict", "Strict").with_allowed_tools(allowlist);

    let shell = ToolDefinition::new("shell_exec", "Shell", ToolCategory::CodeExecution)
        .with_risk_score(0.9);
    let call = AgentAction::new(AgentActionType::ToolCall, "shell_exec".to_string());
    let verdict = policy.evaluate(&call, Some(&shell));

    // The allowlist admits the tool, but its 0.9 risk exceeds the 0.7 threshold.
    assert!(verdict.is_denied());
    let has_risk_finding = verdict
        .findings
        .iter()
        .any(|finding| finding.finding_type == "policy_risk_score_exceeded");
    assert!(has_risk_finding);
}
1602
1603    // ---------------------------------------------------------------
1604    // Integration: engine with multiple policies
1605    // ---------------------------------------------------------------
1606
/// When an audit-only policy and an enforcing policy both block the same tool, the
/// engine's combined decision is a denial: deny wins over the audit policy's warning.
#[test]
fn test_integration_engine_multi_policy() {
    let blocklist: HashSet<String> = ["dangerous_tool"]
        .iter()
        .copied()
        .map(String::from)
        .collect();

    // Registered first: permissive policy that only audits the blocklist.
    let audit = ActionPolicy::permissive("audit", "Audit").with_blocked_tools(blocklist.clone());
    // Registered second: policy that enforces the same blocklist.
    let enforce = ActionPolicy::new("enforce", "Enforce").with_blocked_tools(blocklist);

    let mut engine = PolicyEngine::new();
    engine.add_policy(audit);
    engine.add_policy(enforce);

    let call = AgentAction::new(AgentActionType::ToolCall, "dangerous_tool".to_string());
    let verdict = engine.evaluate_action(&call, None, "s1");
    assert!(verdict.is_denied());
}
1625
/// Exercises one engine end to end:
/// - an allowlisted, low-risk tool call is allowed and recorded against the session;
/// - a tool outside the allowlist is then denied;
/// - minimizing a conversation removes the system message.
#[test]
fn test_integration_full_pipeline() {
    let allowlist: HashSet<String> = ["web_search", "file_read"]
        .iter()
        .copied()
        .map(String::from)
        .collect();

    let production = ActionPolicy::new("prod", "Production")
        .with_allowed_tools(allowlist)
        .with_max_risk_score(0.6)
        .with_max_actions_per_session(5)
        .with_allow_unregistered(false);
    let mut engine = PolicyEngine::new();
    engine.add_policy(production);

    let tool_call = |name: &str| AgentAction::new(AgentActionType::ToolCall, name.to_string());
    let search = ToolDefinition::new("web_search", "Search", ToolCategory::WebAccess)
        .with_risk_score(0.3);

    // Step 1: allowlisted tool with risk 0.3 (ceiling 0.6) is allowed, then counted.
    let search_call = tool_call("web_search");
    let first = engine.evaluate_action(&search_call, Some(&search), "session-1");
    assert!(first.is_allowed());
    engine.record_action("session-1");

    // Step 2: a tool missing from the allowlist is rejected in the same session.
    let shell_call = tool_call("shell_exec");
    let second = engine.evaluate_action(&shell_call, None, "session-1");
    assert!(second.is_denied());

    // Step 3: context minimization drops the system prompt.
    let conversation = vec![
        Message::new("system", "Be helpful."),
        Message::new("user", "Search for cats"),
        Message::new("assistant", "Here are results."),
    ];
    let trimmed = engine.minimize_context(&conversation);
    let has_system = trimmed.iter().any(|msg| msg.role == "system");
    assert!(!has_system);
}
1664
1665    // ---------------------------------------------------------------
1666    // ContextMinimizer — edge cases
1667    // ---------------------------------------------------------------
1668
/// The default minimizer keeps a lone user message, with its content intact.
#[test]
fn test_minimize_single_message() {
    let minimizer = ContextMinimizer::default();
    let input = vec![Message::new("user", "Hello")];

    let output = minimizer.minimize_context(&input);

    assert_eq!(1, output.len());
    assert_eq!(output[0].content, "Hello");
}
1677
/// With a generous configuration (`ContextMinimizer::new(10, false, false, 50_000)`),
/// every input message comes back exactly once, in its original order.
#[test]
fn test_minimize_preserves_order() {
    let minimizer = ContextMinimizer::new(10, false, false, 50_000);
    let messages = vec![
        Message::new("user", "first"),
        Message::new("assistant", "second"),
        Message::new("user", "third"),
    ];
    let result = minimizer.minimize_context(&messages);

    // Pin the length before indexing. Without this, a dropped message fails with an
    // opaque index-out-of-bounds panic, and a duplicated or injected message passes
    // unnoticed.
    assert_eq!(result.len(), messages.len());
    assert_eq!(result[0].content, "first");
    assert_eq!(result[1].content, "second");
    assert_eq!(result[2].content, "third");
}
1691
/// A zero character budget leaves no room for any message, so the output is empty.
#[test]
fn test_minimize_zero_max_chars() {
    let zero_budget = ContextMinimizer::new(10, false, false, 0);
    let input = vec![Message::new("user", "Hello")];
    let kept = zero_budget.minimize_context(&input);
    assert!(kept.is_empty());
}
1699
/// Text that matches none of the minimizer's patterns is returned unchanged.
#[test]
fn test_minimize_text_no_patterns_match() {
    let minimizer = ContextMinimizer::default();
    let plain = "Just a normal message with no secrets.";
    let cleaned = minimizer.minimize_text(plain);
    assert_eq!(cleaned, plain);
}
1707}