use std::collections::HashMap;
use std::sync::Arc;
use arc_swap::ArcSwap;
use crate::web::{Error, RequestContext};
/// Outcome of evaluating a policy for one action.
///
/// Marked `#[must_use]`: silently discarding an authorization decision is
/// almost certainly a bug, so the compiler warns when one is ignored.
#[must_use = "an authorization decision must be acted upon"]
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum Decision {
    /// The action is allowed.
    Permit,
    /// The action is refused. This is also the result when no rule matches.
    Deny,
}
/// Attributes a policy rule can inspect when deciding a single action.
///
/// Rules receive it by shared reference (`&PolicyInput`), so it is read-only
/// during evaluation. In this module it is constructed by `check_policies`.
#[derive(Debug)]
pub struct PolicyInput<'a> {
    /// Name of the action being authorized; selects which rule list is consulted.
    pub action: &'a str,
    /// Claims of the authenticated caller (the subject of the request).
    pub subject: &'a serde_json::Map<String, serde_json::Value>,
    /// Id of the request's tenant, if the request has one.
    pub tenant: Option<&'a str>,
    /// Caller-supplied description of the resource being acted on.
    pub resource: serde_json::Value,
    /// Request-environment attributes.
    pub env: EnvAttributes,
}

/// Request-environment attributes captured alongside a [`PolicyInput`].
#[derive(Debug, Clone)]
pub struct EnvAttributes {
    /// Current time from `crate::auth::session::unix_now`
    /// (NOTE(review): presumably seconds since the Unix epoch — confirm).
    pub unix_now: u64,
    /// Value of `RequestContext::route()` (presumably the matched route pattern — confirm).
    pub route: &'static str,
    /// Request method from `RequestContext::method()`, rendered as a string.
    pub method: String,
}

impl PolicyInput<'_> {
    /// Returns the subject claim `key` when it is present and a JSON string.
    ///
    /// Returns `None` when the claim is missing or holds a non-string value.
    pub fn subject_str(&self, key: &str) -> Option<&str> {
        self.subject.get(key).and_then(serde_json::Value::as_str)
    }

    /// Returns resource field `key` as an `i64`.
    ///
    /// Returns `None` when the resource is not a JSON object, the key is
    /// missing, or the value is not an integer representable as `i64`.
    pub fn resource_i64(&self, key: &str) -> Option<i64> {
        self.resource.get(key)?.as_i64()
    }
}
/// A single rule inside a [`PolicySet`]: a predicate and the decision it yields.
///
/// The predicate is stored behind an `Arc` and bounded by `Send + Sync`, so a
/// rule can be shared across threads.
pub struct CompiledRule {
    /// Predicate over the request; the rule applies when it returns `true`.
    pub when: Arc<dyn Fn(&PolicyInput) -> bool + Send + Sync>,
    /// Decision produced when `when` matches.
    pub effect: Decision,
}
/// A versioned collection of rules, keyed by action name.
///
/// Built with [`PolicySet::new`] followed by chained [`PolicySet::rule`] calls.
/// Evaluation is deny-by-default: an action with no rules, or whose rules all
/// fail to match, is denied.
pub struct PolicySet {
    /// Version tag; `PolicyEngine::reload` ignores a set whose version is not
    /// strictly greater than the live set's.
    pub version: u64,
    /// Rules per action name, kept in the order they were added.
    by_action: HashMap<String, Vec<CompiledRule>>,
}

impl PolicySet {
    /// Creates a set with no rules, tagged with `version`.
    pub fn new(version: u64) -> Self {
        let by_action = HashMap::new();
        Self { by_action, version }
    }

    /// Builder step: appends a rule for `action` and returns the updated set.
    ///
    /// Rules for the same action keep their insertion order, which is the order
    /// in which `evaluate` tries them.
    pub fn rule(
        mut self,
        action: &str,
        effect: Decision,
        when: impl Fn(&PolicyInput) -> bool + Send + Sync + 'static,
    ) -> Self {
        let compiled = CompiledRule {
            effect,
            when: Arc::new(when),
        };
        let bucket = self.by_action.entry(String::from(action)).or_default();
        bucket.push(compiled);
        self
    }

    /// Decides `input.action` with first-match-wins semantics.
    ///
    /// Returns the effect of the first rule (in insertion order) whose predicate
    /// holds. Returns [`Decision::Deny`] when the action has no rules or none match.
    fn evaluate(&self, input: &PolicyInput) -> Decision {
        self.by_action
            .get(input.action)
            .and_then(|rules| rules.iter().find(|candidate| (candidate.when)(input)))
            .map_or(Decision::Deny, |matched| matched.effect)
    }
}
/// Asynchronous provider of [`PolicySet`]s.
///
/// NOTE(review): nothing in this module calls `fetch`; presumably a loader
/// elsewhere feeds its results to `PolicyEngine::reload` — confirm.
pub trait PolicySource: 'static + Send + Sync {
    /// Fetches a policy set, resolving to an error message on failure.
    ///
    /// The returned boxed future (`futures::future::BoxFuture`, which is `Send`)
    /// may borrow `self` for its whole lifetime.
    fn fetch<'s>(&'s self) -> futures::future::BoxFuture<'s, Result<PolicySet, String>>;
}
/// Holds the live [`PolicySet`] and allows it to be replaced while requests are
/// being evaluated.
///
/// The set sits behind an [`ArcSwap`]: readers take a cheap snapshot with
/// `load`, and a reload publishes a whole new set with one atomic pointer swap.
pub struct PolicyEngine {
    set: ArcSwap<PolicySet>,
}

impl PolicyEngine {
    /// Creates an engine whose live set is `initial`.
    pub fn new(initial: PolicySet) -> Self {
        Self {
            set: ArcSwap::from_pointee(initial),
        }
    }

    /// Publishes `next` as the live set if its version is strictly newer.
    ///
    /// An offer whose version is less than or equal to the live version is
    /// ignored and a warning is logged. The version check and the store run as
    /// a single atomic read-copy-update. Two concurrent reloads therefore cannot
    /// interleave in a way that leaves an older set live after a newer one was
    /// published.
    pub fn reload(&self, next: PolicySet) {
        let offered = next.version;
        let next = Arc::new(next);
        // `rcu` re-runs the closure whenever another writer swapped the pointer
        // between our load and our compare-and-swap. The version comparison is
        // therefore always made against the value we actually replace.
        let previous = self.set.rcu(|live| {
            if offered > live.version {
                Arc::clone(&next)
            } else {
                // Stale offer: put the live set back unchanged.
                Arc::clone(live)
            }
        });
        if offered <= previous.version {
            tracing::warn!(
                current = previous.version,
                offered,
                "ignoring stale policy reload"
            );
            return;
        }
        tracing::info!(version = offered, "policy set reloaded (live)");
    }

    /// Evaluates `input` against the currently live set.
    ///
    /// Each call takes its own snapshot. Callers that evaluate several actions
    /// for one request should pin a single snapshot instead, so that all actions
    /// see the same version.
    pub fn evaluate(&self, input: &PolicyInput) -> Decision {
        self.set.load().evaluate(input)
    }

    /// Returns the version of the currently live set.
    pub fn version(&self) -> u64 {
        self.set.load().version
    }
}
/// Authorizes the current request for every action in `actions` against `resource`.
///
/// All actions are evaluated against one snapshot of the live policy set and
/// one set of environment attributes. A concurrent `PolicyEngine::reload`
/// therefore cannot cause different actions of the same request to be judged
/// by different policy versions. Evaluation stops at the first action that is
/// not permitted.
///
/// Note: an empty `actions` slice checks nothing and returns `Ok(())`.
///
/// # Errors
///
/// * [`Error::Unauthorized`] if the request carries no claims.
/// * [`Error::Forbidden`] if no [`PolicyEngine`] can be injected from the
///   context (fail closed), or if any action is not permitted. Each denial
///   increments the `policy_denials_total` counter, labelled with the action.
pub fn check_policies(
    ctx: &RequestContext,
    actions: &[&'static str],
    resource: serde_json::Value,
) -> Result<(), Error> {
    let claims = ctx.claims().ok_or(Error::Unauthorized)?;
    let engine = ctx.try_inject::<PolicyEngine>().ok_or(Error::Forbidden)?;
    // Pin one snapshot for the whole request so every action sees the same set.
    let policies = engine.set.load();
    // Build the input once: the resource is moved in rather than deep-cloned per
    // action, and the environment is sampled a single time. Only `action`
    // changes between iterations (the empty string is a placeholder).
    let mut input = PolicyInput {
        action: "",
        subject: claims,
        tenant: ctx.tenant().map(|t| t.id.as_str()),
        resource,
        env: EnvAttributes {
            unix_now: crate::auth::session::unix_now(),
            route: ctx.route(),
            method: ctx.method().to_string(),
        },
    };
    for &action in actions {
        input.action = action;
        if policies.evaluate(&input) != Decision::Permit {
            metrics::counter!("policy_denials_total", "action" => action).increment(1);
            return Err(Error::Forbidden);
        }
    }
    Ok(())
}