use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use sha2::{Digest, Sha256};
use thiserror::Error;
use uuid::Uuid;
use vigil_audit::{ApprovalTargetContext, EngineDegradedPayload, Ledger, Result as AuditResult};
use vigil_policy::{
DescriptorState, PolicyAction, PolicyContext, PolicyDecision, PolicyEngine, PolicyError,
};
use vigil_types::{ApprovalRequest, DecisionKind, DecisionRecord, EffectVector, ToolInvocation};
use crate::extract::{
BrowserActionExtractor, EffectExtractor, EmailExtractor, PathExtractor, SecretRefExtractor,
ShellExtractor, SqlExtractor, UrlExtractor,
};
use crate::preflight::{run_preflight, EngineStatusReport, PreflightError};
use crate::scorer::{DescriptorOracle, DescriptorStatus, RiskScorer};
/// Errors returned by the firewall.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum FirewallError {
/// Policy evaluation failed. Converted automatically from [`PolicyError`]
/// by `?` (see the `#[from]` attribute).
#[error("policy: {0}")]
Policy(#[from] PolicyError),
/// An audit-ledger operation failed. Converted automatically from
/// `vigil_audit::AuditError`; `compute_args_hash` also reports JSON
/// canonicalisation failures through this variant.
#[error("audit: {0}")]
Audit(#[from] vigil_audit::AuditError),
/// `FirewallConfig::allowed_scopes` contains the key `allowed_hosts`.
/// `Firewall::evaluate` stores the host allowlist under that same key in
/// the policy context, so a scope entry with this name would overwrite it.
#[error(
"config: `allowed_scopes` must not reuse reserved key `allowed_hosts` \
(host allowlist is managed via `FirewallConfig::allowed_hosts`)"
)]
ReservedScopeKey,
/// The preflight scan reported `PreflightError::ScanFailed`.
#[error("preflight scan failed: {reason}")]
PreflightScanFailed {
/// Failure reason carried over unchanged from the preflight error.
reason: String,
},
}
/// Static configuration consumed by [`Firewall`].
#[derive(Debug, Clone)]
pub struct FirewallConfig {
    /// Project root directories. Handed to the risk scorer and (as paths) to
    /// the path extractor, and exposed to policy under the `project_roots` key.
    pub project_roots: Vec<String>,
    /// Host allowlist. Handed to the risk scorer and exposed to policy as the
    /// `allowed_hosts` allowlist.
    pub allowed_hosts: Vec<String>,
    /// Additional named allowlists copied into the policy context verbatim.
    /// Must not contain the key `allowed_hosts`; `Firewall::evaluate` rejects
    /// such a configuration with `FirewallError::ReservedScopeKey`.
    pub allowed_scopes: HashMap<String, Vec<String>>,
    /// TTL passed to the ledger when an approval request is created.
    // NOTE(review): unit is seconds per the field name — confirm against the ledger API.
    pub approval_ttl_secs: u64,
    /// Threshold forwarded to the preflight scan.
    // NOTE(review): exact meaning (bytes vs. chars) is defined by `run_preflight` — confirm.
    pub long_text_threshold: usize,
}

impl Default for FirewallConfig {
    /// Empty roots and allowlists, a 300-second approval TTL and a long-text
    /// threshold of 100.
    fn default() -> Self {
        const DEFAULT_APPROVAL_TTL_SECS: u64 = 300;
        const DEFAULT_LONG_TEXT_THRESHOLD: usize = 100;

        FirewallConfig {
            approval_ttl_secs: DEFAULT_APPROVAL_TTL_SECS,
            long_text_threshold: DEFAULT_LONG_TEXT_THRESHOLD,
            project_roots: Vec::default(),
            allowed_hosts: Vec::default(),
            allowed_scopes: HashMap::default(),
        }
    }
}
/// Scope information supplied by the caller of `Firewall::evaluate`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum OAuthScopeContext {
    /// No scope list is supplied; the policy context receives `None`.
    NonOauth,
    /// The scopes attached to the call; the policy context receives
    /// `Some(list)`, including when the list is empty.
    Scopes(Vec<String>),
}

impl OAuthScopeContext {
    /// Consumes the context and yields the value stored in
    /// `PolicyContext::requested_scopes`.
    fn into_policy_requested_scopes(self) -> Option<Vec<String>> {
        // Kept exhaustive on purpose: a new variant must decide explicitly
        // what the policy engine sees.
        match self {
            Self::Scopes(requested) => Some(requested),
            Self::NonOauth => None,
        }
    }
}
/// Result of `Firewall::evaluate`.
///
/// Every variant carries the decision record and the effect vector that were
/// passed to the ledger when the decision was recorded.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum FirewallOutcome {
    /// Policy allowed the call.
    Allowed {
        decision: DecisionRecord,
        effects: EffectVector,
    },
    /// Policy denied the call (also used for policy actions this module does
    /// not recognise).
    Denied {
        decision: DecisionRecord,
        effects: EffectVector,
    },
    /// Policy requires approval; `approval` is the request created in the
    /// ledger for this decision.
    Approve {
        decision: DecisionRecord,
        effects: EffectVector,
        approval: ApprovalRequest,
    },
}

impl FirewallOutcome {
    /// Returns the [`DecisionKind`] corresponding to this outcome's variant.
    pub fn decision_kind(&self) -> DecisionKind {
        match self {
            Self::Approve { .. } => DecisionKind::Approve,
            Self::Denied { .. } => DecisionKind::Deny,
            Self::Allowed { .. } => DecisionKind::Allow,
        }
    }
}
/// Evaluates tool invocations: extracts effects, scores risk, runs the
/// preflight scan, applies policy and records the result in the audit ledger.
pub struct Firewall {
    /// Audit ledger receiving decisions, approvals and engine-degraded events.
    ledger: Arc<Ledger>,
    /// Policy engine consulted for the final action.
    policy: PolicyEngine,
    /// Risk scorer built from the configured hosts and project roots.
    scorer: RiskScorer,
    /// Effect extractors, run in order on every invocation.
    extractors: Vec<Box<dyn EffectExtractor>>,
    /// Configuration the firewall was constructed with.
    config: FirewallConfig,
    /// Scanner handed to the preflight step.
    scanner: Arc<dyn crate::preflight::PiiScanner>,
    /// Counter bumped when an engine-degraded event cannot be persisted; it is
    /// also handed to the preflight step.
    audit_persist_failures: Arc<crate::preflight::AuditPersistCounter>,
}

impl std::fmt::Debug for Firewall {
    /// Prints a summary (rule count, extractor count, config and the current
    /// persist-failure counter) rather than the ledger, scorer or scanner.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use std::sync::atomic::Ordering;

        let mut summary = f.debug_struct("Firewall");
        summary.field("policy_rule_count", &self.policy.len());
        summary.field("extractor_count", &self.extractors.len());
        summary.field("config", &self.config);
        let persist_failures = self.audit_persist_failures.load(Ordering::Relaxed);
        summary.field("audit_persist_failures", &persist_failures);
        summary.finish()
    }
}
impl Firewall {
/// Builds a firewall that uses the default scanner returned by
/// `crate::preflight::default_scanner_arc`.
///
/// Equivalent to calling [`Firewall::with_scanner`] with that scanner.
pub fn new(ledger: Arc<Ledger>, policy: PolicyEngine, config: FirewallConfig) -> Self {
    let default_scanner = crate::preflight::default_scanner_arc();
    Self::with_scanner(ledger, policy, config, default_scanner)
}
pub fn with_scanner(
ledger: Arc<Ledger>,
policy: PolicyEngine,
config: FirewallConfig,
scanner: Arc<dyn crate::preflight::PiiScanner>,
) -> Self {
let roots: Vec<PathBuf> = config.project_roots.iter().map(PathBuf::from).collect();
let scorer = RiskScorer::new(config.allowed_hosts.clone(), config.project_roots.clone());
let extractors: Vec<Box<dyn EffectExtractor>> = vec![
Box::new(PathExtractor::new(roots)),
Box::new(UrlExtractor),
Box::new(SqlExtractor),
Box::new(ShellExtractor),
Box::new(EmailExtractor),
Box::new(SecretRefExtractor),
Box::new(BrowserActionExtractor),
];
Self {
ledger,
policy,
scorer,
extractors,
config,
scanner,
audit_persist_failures: Arc::new(crate::preflight::AuditPersistCounter::new(0)),
}
}
/// Returns the current value of the persist-failure counter.
///
/// The counter is read with relaxed ordering, so the value is a snapshot.
pub fn audit_persist_failures(&self) -> u64 {
    use std::sync::atomic::Ordering;

    let counter = &self.audit_persist_failures;
    counter.load(Ordering::Relaxed)
}
/// Evaluates one tool invocation and returns the firewall's verdict.
///
/// Steps, in order:
/// 1. Reject a configuration whose `allowed_scopes` uses the reserved key
///    `allowed_hosts`.
/// 2. Run every registered extractor over `call` and de-duplicate the
///    collected effects.
/// 3. Ask `oracle` for the descriptor status and score the effects.
/// 4. Run the preflight scan over `call.args` and add its `risk_delta` to the
///    score, capped at 100.
/// 5. Build the `PolicyContext` and evaluate policy.
/// 6. Record the decision in the ledger; if the preflight engine reported a
///    degraded status, additionally record an engine-degraded event
///    (best effort).
/// 7. Map the policy action to a [`FirewallOutcome`]; `Approve` also creates
///    an approval request in the ledger.
///
/// # Errors
/// - [`FirewallError::ReservedScopeKey`] for the configuration check in step 1.
/// - [`FirewallError::PreflightScanFailed`] when the preflight scan fails.
/// - [`FirewallError::Policy`] when policy evaluation fails.
/// - [`FirewallError::Audit`] when recording the decision, hashing the
///   arguments or creating the approval fails. In the latter two cases the
///   decision has already been recorded.
pub fn evaluate(
&self,
call: &ToolInvocation,
oracle: &dyn DescriptorOracle,
scope_ctx: OAuthScopeContext,
) -> Result<FirewallOutcome, FirewallError> {
// `allowed_hosts` is inserted into `ctx.allowlists` below, and the scope
// entries are inserted into the same map afterwards; a scope entry with this
// key would silently replace the host allowlist, so it is rejected up front.
const RESERVED_ALLOWLIST_KEYS: &[&str] = &["allowed_hosts"];
if self
.config
.allowed_scopes
.keys()
.any(|k| RESERVED_ALLOWLIST_KEYS.contains(&k.as_str()))
{
return Err(FirewallError::ReservedScopeKey);
}
// Collect effects from every extractor into one vector, then de-duplicate.
let mut effects = EffectVector::default();
for ex in &self.extractors {
ex.extract(call, &mut effects);
}
dedup_effects(&mut effects);
// Descriptor status feeds both the risk score and the policy context.
let descriptor = oracle.status(&call.server_id, &call.tool_name, &call.descriptor_hash);
let (risk_score, score_reasons) = self.scorer.score(&effects, descriptor);
// A failed scan aborts the evaluation with an error; no decision is recorded.
let preflight = run_preflight(
self.scanner.as_ref(),
&self.ledger,
&self.audit_persist_failures,
&call.session_id,
&call.args,
self.config.long_text_threshold,
)
.map_err(|e| match e {
PreflightError::ScanFailed { reason } => FirewallError::PreflightScanFailed { reason },
})?;
let base_risk = risk_score;
let pii_delta = preflight.risk_delta;
// Widen before adding so the sum cannot wrap; after `min(100)` the value
// always fits in `u8`.
let risk_with_pii = (base_risk as u32).saturating_add(pii_delta).min(100) as u8;
// Any status not listed explicitly is treated as `FirstSeen`.
// NOTE(review): the `allow` presumably silences the lint while the three
// listed variants cover the whole enum — confirm against `DescriptorStatus`.
#[allow(unreachable_patterns)]
let descriptor_state = match descriptor {
DescriptorStatus::ApprovedStable => DescriptorState::ApprovedStable,
DescriptorStatus::FirstSeen => DescriptorState::FirstSeen,
DescriptorStatus::Drifted => DescriptorState::Drifted,
_ => DescriptorState::FirstSeen,
};
let mut ctx = PolicyContext {
risk_score: risk_with_pii,
descriptor: descriptor_state,
requested_scopes: scope_ctx.into_policy_requested_scopes(),
pii_findings: preflight.pii_summary.clone(),
..Default::default()
};
ctx.roots
.insert("project_roots".into(), self.config.project_roots.clone());
ctx.allowlists
.insert("allowed_hosts".into(), self.config.allowed_hosts.clone());
// Caller-defined allowlists; the reserved-key check above guarantees none of
// them replaces `allowed_hosts`.
for (k, v) in &self.config.allowed_scopes {
ctx.allowlists.insert(k.clone(), v.clone());
}
let pdec: PolicyDecision = self.policy.evaluate(&effects, &ctx)?;
// Human-readable trace of how the final risk score was derived.
let preflight_reason = format!(
"preflight: base_risk={} pii_delta={} final={} labels={}",
base_risk,
pii_delta,
risk_with_pii,
if preflight.pii_summary.is_empty() {
"(none)".to_string()
} else {
preflight.counts_csv()
}
);
// Reason order: scorer reasons, policy reasons, preflight trace, then the
// engine status when degraded.
let mut decision_reasons = merge_reasons(&score_reasons, &pdec.reasons);
decision_reasons.push(preflight_reason);
// Only the two degraded statuses are surfaced; `Ok` and `Unsupported` add
// nothing to the decision.
let degraded_status = match preflight.engine_status {
EngineStatusReport::DegradedTimeout | EngineStatusReport::DegradedError => {
let stable = preflight.engine_status.stable_code();
decision_reasons.push(format!("engine.status={stable}"));
Some(preflight.engine_status)
}
EngineStatusReport::Ok | EngineStatusReport::Unsupported => None,
};
let decision_id = Uuid::new_v4().to_string();
let decision = DecisionRecord {
decision_id: decision_id.clone(),
invocation_id: call.invocation_id.clone(),
decision: map_action(pdec.action),
risk_score: risk_with_pii,
reasons: decision_reasons,
policy_ids: pdec.policy_ids.clone(),
created_at: now_secs(),
};
// The decision must be persisted before any outcome is returned; a ledger
// failure here aborts the evaluation.
let _ = self
.ledger
.record_decision(&call.session_id, &decision, &effects)?;
if let Some(status) = degraded_status {
let payload = EngineDegradedPayload {
engine_id: "firewall_preflight_scanner".to_string(),
status: status.stable_code().to_string(),
reason_code: status.stable_code().to_string(),
budget_ms: None,
elapsed_ms: None,
fail_closed_decision: "fall_back_hard_only".to_string(),
decision_id: decision_id.clone(),
};
// Best effort: a failure to persist the degraded event is counted rather
// than returned, so the already-recorded decision still reaches the caller.
if self
.ledger
.record_engine_degraded(&call.session_id, &payload)
.is_err()
{
self.audit_persist_failures
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
match pdec.action {
PolicyAction::Allow => Ok(FirewallOutcome::Allowed { decision, effects }),
PolicyAction::Deny => Ok(FirewallOutcome::Denied { decision, effects }),
PolicyAction::Approve => {
let (title, summary) = summarize(call, &effects, &decision);
// The approval is tied to the server, tool and a hash of the arguments.
let args_hash = compute_args_hash(&call.args)?;
let ctx = ApprovalTargetContext {
server_id: Some(&call.server_id),
tool_name: Some(&call.tool_name),
args_hash: Some(&args_hash),
};
let approval: AuditResult<ApprovalRequest> = self.ledger.create_approval(
&call.session_id,
&decision,
&effects,
&title,
&summary,
self.config.approval_ttl_secs,
ctx,
);
let approval = approval?;
Ok(FirewallOutcome::Approve {
decision,
effects,
approval,
})
}
// Unknown policy actions fail closed, matching `map_action`.
_ => Ok(FirewallOutcome::Denied { decision, effects }),
}
}
}
/// Removes duplicate entries from every list in the effect vector.
///
/// `effects` keeps its first-seen order (later duplicates are dropped); the
/// path, host, secret-reference and recipient lists are sorted and then
/// de-duplicated.
fn dedup_effects(e: &mut EffectVector) {
    use std::collections::HashSet;

    /// Sorts the list and drops adjacent duplicates.
    fn sort_unique<T: Ord>(items: &mut Vec<T>) {
        items.sort();
        items.dedup();
    }

    // `insert` returns false for a value that is already present, which is
    // exactly the "drop this element" signal `retain` needs.
    let mut already_seen = HashSet::new();
    e.effects.retain(|kind| already_seen.insert(*kind));

    sort_unique(&mut e.paths_read);
    sort_unique(&mut e.paths_write);
    sort_unique(&mut e.network_hosts);
    sort_unique(&mut e.secret_refs);
    sort_unique(&mut e.recipients);
}
/// Converts a policy action into the decision kind stored on the record.
///
/// Actions without an explicit mapping are treated as `Deny`.
fn map_action(a: PolicyAction) -> DecisionKind {
    match a {
        PolicyAction::Approve => DecisionKind::Approve,
        PolicyAction::Allow => DecisionKind::Allow,
        PolicyAction::Deny => DecisionKind::Deny,
        // Fail closed for any action this module does not know about.
        _ => DecisionKind::Deny,
    }
}
/// Concatenates the scorer's reasons and the policy's reasons into one list.
///
/// Scorer reasons come first; order within each input is preserved and
/// duplicates are kept.
fn merge_reasons(score: &[String], policy: &[String]) -> Vec<String> {
    score.iter().chain(policy.iter()).cloned().collect()
}
/// Builds the `(title, summary)` pair shown on an approval request.
///
/// The title is `"<tool> on <server>"`. The summary always starts with the
/// risk score and then adds one segment per non-empty effect list, joined
/// with `" | "`: written paths, hosts and secret references are listed by
/// name, while read paths and recipients are reported as counts only.
fn summarize(
    call: &ToolInvocation,
    effects: &EffectVector,
    dec: &DecisionRecord,
) -> (String, String) {
    let title = format!("{} on {}", call.tool_name, call.server_id);

    // Each optional segment is `None` when its list is empty, so flattening
    // keeps only the segments that have something to report.
    let segments = vec![
        Some(format!("risk {}/100", dec.risk_score)),
        (!effects.paths_write.is_empty())
            .then(|| format!("writes: {}", effects.paths_write.join(", "))),
        (!effects.paths_read.is_empty())
            .then(|| format!("reads: {}", effects.paths_read.len())),
        (!effects.network_hosts.is_empty())
            .then(|| format!("hosts: {}", effects.network_hosts.join(", "))),
        (!effects.secret_refs.is_empty())
            .then(|| format!("secrets: {}", effects.secret_refs.join(", "))),
        (!effects.recipients.is_empty())
            .then(|| format!("recipients: {}", effects.recipients.len())),
    ];
    let summary = segments
        .into_iter()
        .flatten()
        .collect::<Vec<String>>()
        .join(" | ");

    (title, summary)
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time before the epoch.
fn now_secs() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs() as i64,
        Err(_) => 0,
    }
}
/// Returns the lowercase hex SHA-256 digest of the canonical form of `args`.
///
/// The value is serialised with `serde_jcs::to_vec` before hashing, so the
/// digest is computed over that canonical byte sequence rather than over
/// whatever formatting the caller's JSON originally had.
///
/// # Errors
/// Returns `FirewallError::Audit` wrapping `AuditError::Json` when the value
/// cannot be serialised.
pub(crate) fn compute_args_hash(args: &serde_json::Value) -> Result<String, FirewallError> {
    let canonical = match serde_jcs::to_vec(args) {
        Ok(bytes) => bytes,
        Err(e) => return Err(FirewallError::Audit(vigil_audit::AuditError::Json(e))),
    };
    let digest = Sha256::digest(&canonical);
    Ok(hex::encode(digest))
}