use crate::error::NikaError;
use crate::runtime::boot::PolicyConfig;
use url::Url;
/// Outcome of evaluating an action against the active policy.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PolicyDecision {
    /// The action may proceed.
    Allow,
    /// The action is denied; the payload is a human-readable reason.
    Block(String),
    /// The action needs explicit approval; the payload explains why.
    RequiresApproval(String),
}

impl PolicyDecision {
    /// Returns `true` only for [`PolicyDecision::Allow`].
    pub fn is_allowed(&self) -> bool {
        match self {
            PolicyDecision::Allow => true,
            PolicyDecision::Block(_) | PolicyDecision::RequiresApproval(_) => false,
        }
    }

    /// Returns `true` only for [`PolicyDecision::Block`]; a decision that
    /// requires approval is neither allowed nor blocked.
    pub fn is_blocked(&self) -> bool {
        match self {
            PolicyDecision::Block(_) => true,
            PolicyDecision::Allow | PolicyDecision::RequiresApproval(_) => false,
        }
    }
}
/// Tracks token consumption against an optional upper limit.
///
/// `limit == None` means the budget is unlimited. Checking (`can_spend`) and
/// recording (`spend`) are separate steps; `spend` does not enforce the limit.
#[derive(Debug, Clone, Default)]
pub struct TokenBudget {
    /// Maximum number of tokens that may be spent, or `None` for no limit.
    pub limit: Option<u64>,
    /// Tokens recorded as spent so far.
    pub used: u64,
}

impl TokenBudget {
    /// Creates a budget with the given `limit` and nothing spent yet.
    pub fn new(limit: Option<u64>) -> Self {
        Self { limit, used: 0 }
    }

    /// Returns `true` if spending `tokens` more keeps total usage within the
    /// limit; always `true` when the budget is unlimited.
    ///
    /// Uses checked addition: an unchecked `used + tokens` panics in debug
    /// builds and wraps in release builds, where a huge request would wrongly
    /// be reported as affordable.
    pub fn can_spend(&self, tokens: u64) -> bool {
        match self.limit {
            Some(limit) => {
                matches!(self.used.checked_add(tokens), Some(total) if total <= limit)
            }
            None => true,
        }
    }

    /// Records `tokens` as spent, saturating at `u64::MAX` instead of
    /// overflowing. Does not check the limit; call [`TokenBudget::can_spend`]
    /// first if the spend must be gated.
    pub fn spend(&mut self, tokens: u64) {
        self.used = self.used.saturating_add(tokens);
    }

    /// Tokens left before the limit is reached (`Some(0)` once exhausted or
    /// overspent), or `None` when the budget is unlimited.
    pub fn remaining(&self) -> Option<u64> {
        self.limit.map(|l| l.saturating_sub(self.used))
    }
}
/// Evaluates runtime actions (`exec:` commands, `fetch:` URLs, token usage)
/// against a [`PolicyConfig`] and tracks cumulative token spend.
#[derive(Debug, Clone)]
pub struct PolicyEnforcer {
    /// Policy rules consulted by the `check_*` methods.
    config: PolicyConfig,
    /// Running token budget, initialised from `config.max_token_spend`.
    token_budget: TokenBudget,
}

impl Default for PolicyEnforcer {
    /// Builds an enforcer from `PolicyConfig::default()`.
    fn default() -> Self {
        let config = PolicyConfig::default();
        PolicyEnforcer::new(config)
    }
}
impl PolicyEnforcer {
    /// Creates an enforcer for `config`, seeding the token budget from
    /// `config.max_token_spend`.
    pub fn new(config: PolicyConfig) -> Self {
        let token_budget = TokenBudget::new(config.max_token_spend);
        Self {
            config,
            token_budget,
        }
    }

    /// Decides whether `command` may be run by the `exec:` verb.
    ///
    /// Blocks when `allow_exec` is false, or when the command contains any
    /// `blocked_commands` entry (case-insensitive substring match).
    ///
    /// NOTE: this is a best-effort deny-list, not a sandbox. Substring
    /// matching does not see through quoting, extra whitespace, or shell
    /// expansion.
    pub fn check_exec(&self, command: &str) -> PolicyDecision {
        if !self.config.allow_exec {
            return PolicyDecision::Block("exec: verb is disabled by policy".into());
        }
        let command_lower = command.to_lowercase();
        for blocked in &self.config.blocked_commands {
            if command_lower.contains(&blocked.to_lowercase()) {
                return PolicyDecision::Block(format!(
                    "Command contains blocked pattern: '{}'",
                    blocked
                ));
            }
        }
        PolicyDecision::Allow
    }

    /// Decides whether the `fetch:` verb may request `url`.
    ///
    /// Checks, in order:
    /// 1. `allow_network == false` blocks everything.
    /// 2. A URL that cannot be parsed, or has no host, is allowed only when no
    ///    host rules are configured. Otherwise the policy cannot be verified,
    ///    so it is blocked (fail closed).
    /// 3. A host that contains, or is contained in, any `blocked_hosts` entry
    ///    is blocked. This is a case-insensitive substring match and is
    ///    deliberately broad.
    /// 4. If `allowed_hosts` is non-empty, the host must equal an entry or, for
    ///    DNS names, be a subdomain of one. For example, `api.example.com`
    ///    matches `example.com`, but `example.com.evil.net` and
    ///    `notexample.com` do not.
    pub fn check_fetch(&self, url: &str) -> PolicyDecision {
        if !self.config.allow_network {
            return PolicyDecision::Block(
                "fetch: verb (network access) is disabled by policy".into(),
            );
        }

        let has_host_rules =
            !self.config.blocked_hosts.is_empty() || !self.config.allowed_hosts.is_empty();

        let parsed = match Url::parse(url) {
            Ok(u) => u,
            // Nothing to enforce: let the fetch itself report the bad URL.
            Err(_) if !has_host_rules => return PolicyDecision::Allow,
            Err(err) => {
                return PolicyDecision::Block(format!(
                    "URL '{}' could not be parsed ({}); host policy cannot be verified",
                    url, err
                ));
            }
        };
        let host = match parsed.host_str() {
            // Lower-case and drop a trailing root dot so `Example.COM.` is
            // treated like `example.com`.
            Some(h) => h.trim_end_matches('.').to_lowercase(),
            None if !has_host_rules => return PolicyDecision::Allow,
            None => {
                return PolicyDecision::Block(format!(
                    "URL '{}' has no host; host policy cannot be verified",
                    url
                ));
            }
        };
        // `Url::domain` is `Some` only for DNS names, so IP literals never get
        // subdomain matching against the allow-list.
        let host_is_domain = parsed.domain().is_some();

        for blocked in &self.config.blocked_hosts {
            let blocked = blocked.to_lowercase();
            if host.contains(&blocked) || blocked.contains(&host) {
                return PolicyDecision::Block(format!("Host '{}' is blocked by policy", host));
            }
        }

        if !self.config.allowed_hosts.is_empty() {
            let is_allowed = self
                .config
                .allowed_hosts
                .iter()
                .any(|allowed| Self::host_matches_allowed(&host, allowed, host_is_domain));
            if !is_allowed {
                return PolicyDecision::Block(format!(
                    "Host '{}' is not in allowed hosts list",
                    host
                ));
            }
        }
        PolicyDecision::Allow
    }

    /// Returns `true` when `host` (already lower-cased) is permitted by the
    /// allow-list `entry`.
    ///
    /// A match is either exact or, when `host_is_domain`, a subdomain on a
    /// `.` boundary. Entries spelled `*.example.com`, `.example.com`,
    /// `example.com.` or `[::1]` are normalised first. Empty entries never
    /// match.
    fn host_matches_allowed(host: &str, entry: &str, host_is_domain: bool) -> bool {
        let entry = entry.trim().to_lowercase();
        let entry = entry
            .trim_start_matches("*.")
            .trim_matches(|c: char| c == '.' || c == '[' || c == ']');
        let host = host.trim_matches(|c: char| c == '[' || c == ']');
        if entry.is_empty() {
            return false;
        }
        if host == entry {
            return true;
        }
        host_is_domain
            && matches!(host.strip_suffix(entry), Some(prefix) if prefix.ends_with('.'))
    }

    /// Checks whether `tokens` more fit in the remaining budget, without
    /// recording anything. Call [`PolicyEnforcer::record_token_spend`] once the
    /// tokens are actually consumed.
    pub fn check_token_spend(&self, tokens: u64) -> PolicyDecision {
        if !self.token_budget.can_spend(tokens) {
            let remaining = self.token_budget.remaining().unwrap_or(0);
            return PolicyDecision::Block(format!(
                "Token budget exceeded: requested {} but only {} remaining",
                tokens, remaining
            ));
        }
        PolicyDecision::Allow
    }

    /// Adds `tokens` to the running total. The limit is not checked here.
    pub fn record_token_spend(&mut self, tokens: u64) {
        self.token_budget.spend(tokens);
    }

    /// Tokens left in the budget, or `None` when no limit is configured.
    pub fn remaining_budget(&self) -> Option<u64> {
        self.token_budget.remaining()
    }

    /// Total tokens recorded so far.
    pub fn tokens_used(&self) -> u64 {
        self.token_budget.used
    }

    /// Converts a decision into a `Result`.
    ///
    /// `Allow` becomes `Ok(())`. Both `Block` and `RequiresApproval` become
    /// `NikaError::PolicyViolation`; the approval reason is prefixed with
    /// `"Requires approval: "`.
    ///
    /// # Errors
    /// Returns `NikaError::PolicyViolation` for any non-`Allow` decision.
    pub fn enforce(&self, decision: PolicyDecision) -> Result<(), NikaError> {
        match decision {
            PolicyDecision::Allow => Ok(()),
            PolicyDecision::Block(reason) => Err(NikaError::PolicyViolation { reason }),
            PolicyDecision::RequiresApproval(reason) => Err(NikaError::PolicyViolation {
                reason: format!("Requires approval: {}", reason),
            }),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an enforcer from a customised config.
    fn enforcer_from(config: PolicyConfig) -> PolicyEnforcer {
        PolicyEnforcer::new(config)
    }

    #[test]
    fn test_default_policy_allows_exec() {
        let decision = PolicyEnforcer::default().check_exec("ls -la");
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_policy_blocks_dangerous_commands() {
        let enforcer = PolicyEnforcer::default();
        let dangerous = ["sudo apt install", "rm -rf /", "chmod 777 /etc"];
        let benign = ["echo hello", "npm run build"];
        for cmd in dangerous.iter() {
            assert!(
                enforcer.check_exec(cmd).is_blocked(),
                "expected {:?} to be blocked",
                cmd
            );
        }
        for cmd in benign.iter() {
            assert!(
                enforcer.check_exec(cmd).is_allowed(),
                "expected {:?} to be allowed",
                cmd
            );
        }
    }

    #[test]
    fn test_policy_disables_exec() {
        let enforcer = enforcer_from(PolicyConfig {
            allow_exec: false,
            ..Default::default()
        });
        assert!(enforcer.check_exec("echo hello").is_blocked());
    }

    #[test]
    fn test_default_policy_allows_fetch() {
        let decision = PolicyEnforcer::default().check_fetch("https://api.example.com/data");
        assert!(decision.is_allowed());
    }

    #[test]
    fn test_policy_disables_network() {
        let enforcer = enforcer_from(PolicyConfig {
            allow_network: false,
            ..Default::default()
        });
        assert!(enforcer.check_fetch("https://example.com").is_blocked());
    }

    #[test]
    fn test_policy_blocks_hosts() {
        let enforcer = enforcer_from(PolicyConfig {
            blocked_hosts: vec!["evil.com".into(), "malware".into()],
            ..Default::default()
        });
        assert!(enforcer.check_fetch("https://evil.com/path").is_blocked());
        assert!(enforcer.check_fetch("https://malware.io/api").is_blocked());
        assert!(enforcer.check_fetch("https://api.example.com").is_allowed());
    }

    #[test]
    fn test_policy_allowed_hosts_whitelist() {
        let enforcer = enforcer_from(PolicyConfig {
            allowed_hosts: vec!["api.openai.com".into(), "anthropic.com".into()],
            ..Default::default()
        });
        for url in ["https://api.openai.com/v1", "https://anthropic.com/api"].iter() {
            assert!(
                enforcer.check_fetch(url).is_allowed(),
                "expected {:?} to be allowed",
                url
            );
        }
        assert!(enforcer.check_fetch("https://other.com/api").is_blocked());
    }

    #[test]
    fn test_token_budget_unlimited() {
        let budget = TokenBudget::new(None);
        assert!(budget.can_spend(1_000_000));
        assert_eq!(budget.remaining(), None);
    }

    #[test]
    fn test_token_budget_limited() {
        let mut budget = TokenBudget::new(Some(10_000));
        assert!(budget.can_spend(5_000));
        budget.spend(5_000);
        assert_eq!(budget.used, 5_000);
        assert_eq!(budget.remaining(), Some(5_000));
        assert!(budget.can_spend(5_000));
        assert!(!budget.can_spend(5_001));
    }

    #[test]
    fn test_enforcer_token_budget() {
        let mut enforcer = enforcer_from(PolicyConfig {
            max_token_spend: Some(1000),
            ..Default::default()
        });
        // Two 500-token spends exactly exhaust the 1000-token budget.
        for _ in 0..2 {
            assert!(enforcer.check_token_spend(500).is_allowed());
            enforcer.record_token_spend(500);
        }
        assert!(enforcer.check_token_spend(1).is_blocked());
        assert_eq!(enforcer.remaining_budget(), Some(0));
    }

    #[test]
    fn test_policy_decision_properties() {
        let allow = PolicyDecision::Allow;
        let block = PolicyDecision::Block("reason".into());
        assert!(allow.is_allowed() && !allow.is_blocked());
        assert!(block.is_blocked() && !block.is_allowed());
    }
}