use std::collections::HashMap;
use super::types::{AccessRequest, Effect, Operation, Refined, Resource, Verdict};
use crate::permission::pattern::Pattern;
/// Tally of how many times each `(operation, key)` pair has been recorded.
///
/// The same key under different operations is tracked independently.
#[derive(Debug, Default)]
pub struct RepeatCounter {
    // Keyed by the composite string built in `RepeatCounter::key`.
    counts: HashMap<String, u32>,
}

impl RepeatCounter {
    /// Builds the map key for `(op, key)`: the operation's `Debug` text, a NUL byte, then `key`.
    ///
    /// NOTE(review): this assumes `Operation`'s `Debug` output never contains a NUL byte. That
    /// keeps the encoding unambiguous, because the first NUL always ends the operation part.
    fn key(op: Operation, key: &str) -> String {
        format!("{op:?}\x00{key}")
    }

    /// Returns how many times `(op, key)` has been recorded.
    ///
    /// The count runs from the last `reset` of that pair or `clear` of the whole counter.
    /// Returns `0` if the pair has never been recorded.
    pub fn prior(&self, op: Operation, key: &str) -> u32 {
        self.counts.get(&Self::key(op, key)).copied().unwrap_or(0)
    }

    /// Increments the count for `(op, key)`.
    ///
    /// The count saturates at `u32::MAX` instead of overflowing. An overflow would panic in
    /// debug builds, and in release builds it would wrap to 0 so `prior` reported "never seen".
    pub fn record(&mut self, op: Operation, key: &str) {
        let count = self.counts.entry(Self::key(op, key)).or_default();
        *count = count.saturating_add(1);
    }

    /// Forgets the count for `(op, key)`.
    ///
    /// `prior` returns `0` for the pair until it is recorded again.
    pub fn reset(&mut self, op: Operation, key: &str) {
        self.counts.remove(&Self::key(op, key));
    }

    /// Forgets every recorded count.
    pub fn clear(&mut self) {
        self.counts.clear();
    }
}
/// One allow rule held by the session allowlist.
///
/// The rule pairs an operation with a compiled pattern. It also keeps the original text the
/// pattern was built from, which is used for de-duplication and listing.
#[derive(Clone, Debug)]
pub struct AllowEntry {
    /// Operation this rule is scoped to.
    pub op: Operation,
    /// Matcher tested against request keys.
    pub pattern: Pattern,
    /// Source text of `pattern`, as supplied by the caller.
    pub original: String,
}
/// Ordered list of allow rules, each scoped to one operation.
#[derive(Debug, Default)]
pub struct SessionAllowlist {
    entries: Vec<AllowEntry>,
}

impl SessionAllowlist {
    /// Appends a rule, unless a rule with the same `op` and `original` text already exists.
    ///
    /// When a duplicate exists, the new `pattern` is dropped and the existing rule is kept.
    pub fn add(&mut self, op: Operation, original: &str, pattern: Pattern) {
        let already_present = self
            .entries
            .iter()
            .any(|existing| existing.op == op && existing.original == original);
        if !already_present {
            self.entries.push(AllowEntry {
                op,
                pattern,
                original: original.to_string(),
            });
        }
    }

    /// Returns `true` if any rule scoped to `op` has a pattern that matches `key`.
    pub fn allows(&self, op: Operation, key: &str) -> bool {
        self.entries
            .iter()
            .filter(|rule| rule.op == op)
            .any(|rule| rule.pattern.matches(key))
    }

    /// Iterates over `(operation, original text)` for every rule, in insertion order.
    pub fn entries(&self) -> impl Iterator<Item = (Operation, &str)> {
        self.entries
            .iter()
            .map(|AllowEntry { op, original, .. }| (*op, original.as_str()))
    }

    /// Removes every rule.
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Removes the rule at position `idx` and returns its `(operation, original text)`.
    ///
    /// Returns `None` and leaves the list unchanged when `idx` is out of range.
    pub fn remove_at(&mut self, idx: usize) -> Option<(Operation, String)> {
        (idx < self.entries.len()).then(|| {
            let AllowEntry { op, original, .. } = self.entries.remove(idx);
            (op, original)
        })
    }

    /// Removes every rule whose operation is `op` and whose original text is `original`.
    ///
    /// Returns how many rules were removed.
    pub fn remove(&mut self, op: Operation, original: &str) -> usize {
        let count_before = self.entries.len();
        self.entries
            .retain(|rule| rule.op != op || rule.original != original);
        count_before - self.entries.len()
    }
}
/// Mutable state shared with deciders and modifiers during evaluation.
#[derive(Default, Debug)]
pub struct PolicyCtx {
    /// Per-`(operation, key)` repeat tallies.
    pub repeat: RepeatCounter,
    /// Session-scoped allow rules.
    pub allowlist: SessionAllowlist,
    /// NOTE(review): a list of deny strings, presumably gathered from prompts. Confirm their
    /// format and how they are matched with the code that consumes them.
    pub prompt_deny: Vec<String>,
}
/// A policy component that can return a [`Verdict`] for an access request.
///
/// Implementors must be both `Send` and `Sync`.
pub trait Decider: Sync + Send {
    /// Identifier string for this decider, borrowed for the whole program (`'static`).
    fn id(&self) -> &'static str;

    /// Reports whether this decider is relevant to `op` on `resource`.
    ///
    /// NOTE(review): presumably callers only invoke `decide` when this returns `true`.
    /// Confirm this against the evaluation loop.
    fn applies_to(&self, op: Operation, resource: &Resource) -> bool;

    /// Evaluates `req` against `op`/`resource` with read-only access to `ctx`.
    ///
    /// NOTE(review): `None` presumably means "no opinion, defer to other deciders". Confirm
    /// this with the caller.
    fn decide(
        &self,
        req: &AccessRequest,
        op: Operation,
        resource: &Resource,
        ctx: &PolicyCtx,
    ) -> Option<Verdict>;
}
/// A policy component that takes an already-computed [`Effect`] and returns a [`Refined`]
/// result.
///
/// Implementors must be both `Send` and `Sync`.
pub trait Modifier: Sync + Send {
    /// Identifier string for this modifier, borrowed for the whole program (`'static`).
    fn id(&self) -> &'static str;

    /// Reports whether this modifier is relevant to `op` on `resource`.
    ///
    /// NOTE(review): presumably callers only invoke `refine` when this returns `true`.
    /// Confirm this against the evaluation loop.
    fn applies_to(&self, op: Operation, resource: &Resource) -> bool;

    /// Produces a [`Refined`] outcome from the `current` effect for `req` on `op`/`resource`.
    ///
    /// `ctx` is available read-only.
    fn refine(
        &self,
        req: &AccessRequest,
        op: Operation,
        resource: &Resource,
        current: Effect,
        ctx: &PolicyCtx,
    ) -> Refined;
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn repeat_counter_counts_after_record() {
        let mut c = RepeatCounter::default();
        assert_eq!(c.prior(Operation::Execute, "cargo test"), 0);
        c.record(Operation::Execute, "cargo test");
        assert_eq!(c.prior(Operation::Execute, "cargo test"), 1);
        c.record(Operation::Execute, "cargo test");
        assert_eq!(c.prior(Operation::Execute, "cargo test"), 2);
        assert_eq!(c.prior(Operation::Execute, "cargo build"), 0);
        assert_eq!(c.prior(Operation::Read, "cargo test"), 0);
        c.clear();
        assert_eq!(c.prior(Operation::Execute, "cargo test"), 0);
    }

    #[test]
    fn repeat_counter_reset_is_key_scoped() {
        let mut c = RepeatCounter::default();
        c.record(Operation::Execute, "cargo test");
        c.record(Operation::Execute, "cargo build");
        c.reset(Operation::Execute, "cargo test");
        assert_eq!(c.prior(Operation::Execute, "cargo test"), 0);
        assert_eq!(c.prior(Operation::Execute, "cargo build"), 1);
        // Resetting a pair that was never recorded must not disturb other counts.
        c.reset(Operation::Read, "never recorded");
        assert_eq!(c.prior(Operation::Execute, "cargo build"), 1);
    }

    #[test]
    fn session_allowlist_op_scoped_match_and_dedup() {
        let mut al = SessionAllowlist::default();
        al.add(
            Operation::Execute,
            "cargo *",
            Pattern::new_command("cargo *"),
        );
        al.add(
            Operation::Execute,
            "cargo *",
            Pattern::new_command("cargo *"),
        );
        assert_eq!(al.entries().count(), 1);
        assert!(al.allows(Operation::Execute, "cargo test --bin dirge"));
        assert!(!al.allows(Operation::Execute, "git status"));
        al.add(Operation::Edit, "src/**", Pattern::new("src/**"));
        assert!(al.allows(Operation::Edit, "src/main.rs"));
        assert!(!al.allows(Operation::Read, "src/main.rs"));
    }

    #[test]
    fn session_allowlist_remove_and_remove_at() {
        let mut al = SessionAllowlist::default();
        al.add(
            Operation::Execute,
            "cargo *",
            Pattern::new_command("cargo *"),
        );
        al.add(Operation::Edit, "src/**", Pattern::new("src/**"));

        // An out-of-range index is a no-op rather than a panic.
        assert_eq!(al.remove_at(5), None);
        assert_eq!(al.entries().count(), 2);

        // `remove` is scoped by operation: same text under another op removes nothing.
        assert_eq!(al.remove(Operation::Read, "src/**"), 0);
        assert_eq!(al.remove(Operation::Edit, "src/**"), 1);
        assert!(!al.allows(Operation::Edit, "src/main.rs"));

        assert_eq!(
            al.remove_at(0),
            Some((Operation::Execute, "cargo *".to_string()))
        );
        assert_eq!(al.entries().count(), 0);
    }
}