1use anyhow::{Context, Result};
6use starlark::environment::{FrozenModule, Globals, GlobalsBuilder, Module};
7use starlark::eval::Evaluator;
8use starlark::starlark_module;
9use starlark::syntax::{AstModule, Dialect};
10use starlark::values::none::NoneType;
11use std::path::{Path, PathBuf};
12
/// Verdict produced by [`PolicyEngine::evaluate`] for a single command.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PolicyDecision {
    /// The command may run.
    Allow,
    /// The command needs confirmation before running; carries a human-readable reason.
    Prompt(String),
    /// The command must not run; carries a human-readable reason.
    Deny(String),
}
23
24pub struct PolicyEngine {
26 policies: Vec<FrozenModule>,
28 policy_dir: PathBuf,
30}
31
32impl PolicyEngine {
33 pub fn new() -> Result<Self> {
35 let policy_dir = Self::default_policy_dir();
36 let mut engine = Self {
37 policies: Vec::new(),
38 policy_dir: policy_dir.clone(),
39 };
40
41 if policy_dir.exists() {
43 engine.load_policies()?;
44 } else {
45 log::info!(
46 "Policy directory {:?} does not exist, using defaults",
47 policy_dir
48 );
49 }
50
51 Ok(engine)
52 }
53
54 pub fn default_policy_dir() -> PathBuf {
56 std::env::var("HOME")
58 .map(PathBuf::from)
59 .unwrap_or_else(|_| PathBuf::from("."))
60 .join(".perspt")
61 .join("rules")
62 }
63
64 pub fn load_policies(&mut self) -> Result<()> {
66 if !self.policy_dir.exists() {
67 return Ok(());
68 }
69
70 for entry in std::fs::read_dir(&self.policy_dir)? {
71 let entry = entry?;
72 let path = entry.path();
73
74 if path.extension().is_some_and(|ext| ext == "star") {
75 match self.load_policy_file(&path) {
76 Ok(module) => {
77 self.policies.push(module);
78 log::info!("Loaded policy: {:?}", path);
79 }
80 Err(e) => {
81 log::warn!("Failed to load policy {:?}: {}", path, e);
82 }
83 }
84 }
85 }
86
87 log::info!("Loaded {} policies", self.policies.len());
88 Ok(())
89 }
90
91 fn load_policy_file(&self, path: &Path) -> Result<FrozenModule> {
93 let content = std::fs::read_to_string(path)
94 .context(format!("Failed to read policy file: {:?}", path))?;
95
96 let ast = AstModule::parse(path.to_string_lossy().as_ref(), content, &Dialect::Standard)
97 .map_err(|e| anyhow::anyhow!("Parse error: {}", e))?;
98
99 let globals = Self::create_globals();
100 let module = Module::new();
101
102 {
103 let mut eval = Evaluator::new(&module);
104 eval.eval_module(ast, &globals)
105 .map_err(|e| anyhow::anyhow!("Eval error: {}", e))?;
106 }
107
108 Ok(module.freeze()?)
109 }
110
111 fn create_globals() -> Globals {
113 #[starlark_module]
114 fn policy_builtins(builder: &mut GlobalsBuilder) {
115 fn matches_pattern(command: &str, pattern: &str) -> anyhow::Result<bool> {
117 Ok(command.contains(pattern))
118 }
119
120 fn log_policy(message: &str) -> anyhow::Result<NoneType> {
122 log::info!("[Policy] {}", message);
123 Ok(NoneType)
124 }
125 }
126
127 GlobalsBuilder::standard().with(policy_builtins).build()
128 }
129
130 pub fn evaluate(&self, command: &str) -> PolicyDecision {
132 if self.policies.is_empty() {
134 return self.default_policy(command);
135 }
136
137 self.default_policy(command)
140 }
141
142 fn default_policy(&self, command: &str) -> PolicyDecision {
144 let dangerous_patterns = ["rm -rf", "sudo", "chmod 777", "> /dev/", "mkfs", "dd if="];
146
147 for pattern in &dangerous_patterns {
148 if command.contains(pattern) {
149 return PolicyDecision::Deny(format!(
150 "Command contains dangerous pattern: {}",
151 pattern
152 ));
153 }
154 }
155
156 let network_patterns = ["curl", "wget", "nc ", "ssh ", "scp "];
158 for pattern in &network_patterns {
159 if command.contains(pattern) {
160 return PolicyDecision::Prompt(format!(
161 "Command requires network access: {}",
162 command
163 ));
164 }
165 }
166
167 if command.contains("git push") || command.contains("git force") {
169 return PolicyDecision::Prompt("Git push operation requires confirmation".to_string());
170 }
171
172 PolicyDecision::Allow
173 }
174
175 pub fn is_safe(&self, command: &str) -> bool {
177 matches!(self.evaluate(command), PolicyDecision::Allow)
178 }
179}
180
181impl Default for PolicyEngine {
182 fn default() -> Self {
183 Self::new().unwrap_or_else(|_| Self {
184 policies: Vec::new(),
185 policy_dir: PathBuf::from("."),
186 })
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine with no Starlark policies and a nonexistent policy directory.
    /// Tests therefore do not depend on the developer's `~/.perspt/rules`.
    fn engine() -> PolicyEngine {
        PolicyEngine {
            policies: Vec::new(),
            policy_dir: std::path::PathBuf::from("/nonexistent/perspt-policy-tests"),
        }
    }

    #[test]
    fn test_default_policy_allows_safe_commands() {
        let engine = engine();
        assert!(matches!(
            engine.evaluate("cargo build"),
            PolicyDecision::Allow
        ));
        assert!(matches!(engine.evaluate("ls -la"), PolicyDecision::Allow));
    }

    #[test]
    fn test_default_policy_denies_dangerous() {
        let engine = engine();
        assert!(matches!(
            engine.evaluate("rm -rf /"),
            PolicyDecision::Deny(_)
        ));
        assert!(matches!(
            engine.evaluate("sudo rm file"),
            PolicyDecision::Deny(_)
        ));
        assert!(matches!(
            engine.evaluate("dd if=/dev/zero of=disk.img"),
            PolicyDecision::Deny(_)
        ));
    }

    #[test]
    fn test_default_policy_prompts_network() {
        let engine = engine();
        assert!(matches!(
            engine.evaluate("curl https://example.com"),
            PolicyDecision::Prompt(_)
        ));
    }

    #[test]
    fn test_default_policy_prompts_git_push() {
        assert_eq!(
            engine().evaluate("git push origin main"),
            PolicyDecision::Prompt("Git push operation requires confirmation".to_string())
        );
    }

    #[test]
    fn test_deny_takes_precedence_over_prompt() {
        // Matches both "sudo" (deny) and "curl" (prompt); dangerous patterns are checked first.
        assert_eq!(
            engine().evaluate("sudo curl https://example.com"),
            PolicyDecision::Deny("Command contains dangerous pattern: sudo".to_string())
        );
    }

    #[test]
    fn test_is_safe_only_for_allow() {
        let engine = engine();
        assert!(engine.is_safe("cargo test"));
        assert!(!engine.is_safe("rm -rf /tmp/x"));
        assert!(!engine.is_safe("wget https://example.com"));
    }

    #[test]
    fn test_load_policies_missing_dir_is_noop() {
        let mut engine = engine();
        engine.load_policies().unwrap();
        assert!(engine.policies.is_empty());
    }
}