1use std::path::Path;
2use std::sync::Arc;
3
4use chrono::{DateTime, Local, Timelike, Utc};
5use glob::glob;
6use serde::{Deserialize, Serialize};
7use sqlx::{Row, SqlitePool};
8use tokio::sync::RwLock;
9use uuid::Uuid;
10
11use crate::config::GuardConfig;
12use crate::error::{GuardError, GuardResult};
13use crate::masking::MaskType;
14use crate::types::PolicyDecision;
15
/// A single rule inside a [`Policy`].
///
/// Rules are persisted as part of a policy's JSON-serialized `rules` column
/// (written with `serde_json::to_string`, read back with `serde_json::from_str`),
/// so the serde representation of this enum is stored data.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum PolicyRule {
    /// Grants access when `condition` matches the evaluation context.
    AllowIf { condition: Condition },
    /// Denies access with `reason` when `condition` matches.
    DenyIf {
        /// Predicate that triggers the denial.
        condition: Condition,
        /// Explanation returned in `PolicyDecision::Deny`.
        reason: String,
    },
    /// Requests masking of fields matching `field_pattern`; this rule has no condition.
    MaskField {
        /// Pattern naming the fields to mask; its syntax is interpreted outside this
        /// file — TODO confirm.
        field_pattern: String,
        /// Masking strategy. NOTE(review): `PolicyEngine::evaluate` only forwards
        /// `field_pattern` into the decision; this value is not read there.
        mask_type: MaskType,
    },
}
32
/// A boolean predicate over an [`EvalContext`], evaluated by `Condition::matches`.
///
/// Conditions are persisted inside a policy's JSON `rules` column, so their serde
/// form is stored data.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Condition {
    /// True when `ctx.task` equals the string exactly (no pattern matching is done).
    TaskMatches(String),
    /// True when `ctx.role` equals the string.
    RoleIs(String),
    /// True when `ctx.scopes` contains an element equal to the string.
    ScopeContains(String),
    /// True when `ctx.risk_score` is strictly greater than the threshold.
    RiskAbove(f64),
    /// True when `ctx.resource` equals the string.
    ResourceIs(String),
    /// True when `ctx.workspace_id` equals the UUID.
    WorkspaceIs(Uuid),
    /// True when both operands match (the right side is skipped if the left fails).
    And(Box<Condition>, Box<Condition>),
    /// True when either operand matches (the right side is skipped if the left holds).
    Or(Box<Condition>, Box<Condition>),
    /// Negation of the inner condition.
    Not(Box<Condition>),
}
55
/// Request attributes that policy conditions and risk scoring are evaluated against.
#[derive(Debug, Clone, PartialEq)]
pub struct EvalContext {
    /// Requesting agent. NOTE(review): not read by `Condition::matches` or
    /// `PolicyEngine::compute_risk` in this file.
    pub agent_id: Uuid,
    /// Workspace of the request; compared by `Condition::WorkspaceIs`.
    pub workspace_id: Uuid,
    /// Caller role; compared for equality by `Condition::RoleIs`.
    pub role: String,
    /// Granted scopes; `Condition::ScopeContains` checks for an equal element.
    pub scopes: Vec<String>,
    /// Task identifier; `Condition::TaskMatches` compares it for exact equality.
    pub task: String,
    /// Target resource; matched by `Condition::ResourceIs` and checked against
    /// `GuardConfig::sensitive_resources` in `PolicyEngine::compute_risk`.
    pub resource: String,
    /// Action name; `compute_risk` looks for `write`/`update`/`delete` substrings
    /// after lowercasing it.
    pub action: String,
    /// Risk score compared by `Condition::RiskAbove`. `compute_risk` produces values
    /// in `[0.0, 1.0]`; presumably callers store that result here — confirm.
    pub risk_score: f64,
}
76
/// A named set of rules persisted in the `policies` table.
#[derive(Debug, Clone, PartialEq)]
pub struct Policy {
    /// Row id; stored in its string form.
    pub id: Uuid,
    /// Policy name; upserts match on it (`ON CONFLICT(name)`). When loading TOML
    /// without a `name`, the caller's fallback name is used (the file stem in
    /// `load_from_dir`).
    pub name: String,
    /// Optional free-form description.
    pub description: Option<String>,
    /// Rules, evaluated in order; stored as JSON in the `rules` column.
    pub rules: Vec<PolicyRule>,
    /// Evaluation order: policies are loaded with `ORDER BY priority DESC`.
    pub priority: i32,
    /// Disabled policies stay stored but are not loaded (`WHERE enabled = 1`).
    pub enabled: bool,
    /// Creation time; stored as Unix milliseconds.
    pub created_at: DateTime<Utc>,
    /// Last update time; stored as Unix milliseconds and used as the
    /// tie-breaker after priority (most recent first).
    pub updated_at: DateTime<Utc>,
}
97
/// Root of a multi-policy TOML document: a top-level `policies` array of tables.
#[derive(Debug, Deserialize)]
struct PolicyFileRoot {
    /// Policies in document order; empty when the key is absent.
    #[serde(default)]
    policies: Vec<PolicyFile>,
}
103
/// Raw TOML form of a single policy; converted to [`Policy`] by `policy_from_file`.
#[derive(Debug, Deserialize)]
struct PolicyFile {
    /// Policy name; when absent, a caller-supplied fallback name is used.
    name: Option<String>,
    /// Optional description.
    description: Option<String>,
    /// Evaluation priority (higher is evaluated first); defaults to 0.
    #[serde(default)]
    priority: i32,
    /// Whether the policy is active; defaults to `true` via `default_enabled`.
    #[serde(default = "default_enabled")]
    enabled: bool,
    /// Rule definitions in evaluation order; defaults to empty.
    #[serde(default)]
    rules: Vec<PolicyRuleFile>,
}
115
/// Raw TOML form of a rule; validated and converted by `parse_rule`.
#[derive(Debug, Deserialize)]
struct PolicyRuleFile {
    /// Rule kind, spelled `type` in TOML: `allow_if`, `deny_if` or `mask_field`.
    #[serde(rename = "type")]
    kind: String,
    /// Denial reason; required for `deny_if`.
    reason: Option<String>,
    /// Condition table; required for `allow_if` and `deny_if`.
    condition: Option<toml::Value>,
    /// Field pattern; required for `mask_field`.
    field_pattern: Option<String>,
    /// Mask type (string or table); required for `mask_field`.
    mask_type: Option<toml::Value>,
}
125
126#[derive(Clone)]
128pub struct PolicyEngine {
129 pool: SqlitePool,
130 config: Arc<GuardConfig>,
131 cache: Arc<RwLock<Vec<Policy>>>,
132}
133
134impl PolicyEngine {
135 pub fn new(pool: SqlitePool, config: Arc<GuardConfig>) -> Self {
137 Self {
138 pool,
139 config,
140 cache: Arc::new(RwLock::new(Vec::new())),
141 }
142 }
143
144 pub async fn evaluate(&self, ctx: &EvalContext) -> GuardResult<PolicyDecision> {
146 let policies = self.cached_policies().await?;
147 for policy in policies {
148 for rule in &policy.rules {
149 match rule {
150 PolicyRule::AllowIf { condition } if condition.matches(ctx) => {
151 return Ok(PolicyDecision::Allow);
152 }
153 PolicyRule::DenyIf { condition, reason } if condition.matches(ctx) => {
154 return Ok(PolicyDecision::Deny {
155 reason: reason.clone(),
156 });
157 }
158 PolicyRule::MaskField { field_pattern, .. } => {
159 return Ok(PolicyDecision::Mask {
160 fields: vec![field_pattern.clone()],
161 });
162 }
163 _ => {}
164 }
165 }
166 }
167
168 Ok(PolicyDecision::Deny {
169 reason: "no matching policy rule".to_owned(),
170 })
171 }
172
173 pub fn compute_risk(&self, ctx: &EvalContext) -> f64 {
175 let mut score: f64 = 0.0;
176 let action = ctx.action.to_ascii_lowercase();
177 if action.contains("write") || action.contains("update") {
178 score += 0.3;
179 }
180 if action.contains("delete") {
181 score += 0.5;
182 }
183 if self
184 .config
185 .sensitive_resources
186 .iter()
187 .any(|resource| resource == &ctx.resource)
188 {
189 score += 0.3;
190 }
191 let hour = Local::now().hour() as u8;
192 if hour < self.config.business_hours_start_hour
193 || hour >= self.config.business_hours_end_hour
194 {
195 score += 0.1;
196 }
197 score.clamp(0.0, 1.0)
198 }
199
200 pub async fn load_from_dir(&self, dir: &Path) -> GuardResult<usize> {
202 if !dir.exists() {
203 return Ok(0);
204 }
205
206 let pattern = dir.join("*.toml").to_string_lossy().to_string();
207 let mut loaded = 0_usize;
208 for entry in glob(&pattern).map_err(|error| GuardError::ConfigError(error.to_string()))? {
209 let path = entry.map_err(|error| GuardError::ConfigError(error.to_string()))?;
210 let source = std::fs::read_to_string(&path)?;
211 self.add_policy_from_toml(
212 &source,
213 path.file_stem()
214 .and_then(|value| value.to_str())
215 .unwrap_or("policy"),
216 )
217 .await?;
218 loaded += 1;
219 }
220 Ok(loaded)
221 }
222
223 pub async fn add_policy_from_toml(
225 &self,
226 source: &str,
227 fallback_name: &str,
228 ) -> GuardResult<Policy> {
229 let policies = parse_policy_source(source, fallback_name)?;
230 let mut last = None;
231 for policy in policies {
232 self.upsert_policy(&policy).await?;
233 last = Some(policy);
234 }
235 last.ok_or_else(|| GuardError::ConfigError("policy file contained no policies".to_owned()))
236 }
237
238 pub async fn list_policies(&self) -> GuardResult<Vec<Policy>> {
240 self.load_policies_from_db().await
241 }
242
243 pub async fn remove_policy(&self, policy_id: Uuid) -> GuardResult<()> {
245 sqlx::query("DELETE FROM policies WHERE id = ?1")
246 .bind(policy_id.to_string())
247 .execute(&self.pool)
248 .await?;
249 self.invalidate_cache().await;
250 Ok(())
251 }
252
253 async fn upsert_policy(&self, policy: &Policy) -> GuardResult<()> {
254 sqlx::query(
255 "INSERT INTO policies (id, name, description, rules, priority, enabled, created_at, updated_at)
256 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
257 ON CONFLICT(name) DO UPDATE SET
258 description = excluded.description,
259 rules = excluded.rules,
260 priority = excluded.priority,
261 enabled = excluded.enabled,
262 updated_at = excluded.updated_at",
263 )
264 .bind(policy.id.to_string())
265 .bind(&policy.name)
266 .bind(&policy.description)
267 .bind(serde_json::to_string(&policy.rules)?)
268 .bind(policy.priority)
269 .bind(if policy.enabled { 1_i64 } else { 0_i64 })
270 .bind(policy.created_at.timestamp_millis())
271 .bind(policy.updated_at.timestamp_millis())
272 .execute(&self.pool)
273 .await?;
274 self.invalidate_cache().await;
275 Ok(())
276 }
277
278 async fn cached_policies(&self) -> GuardResult<Vec<Policy>> {
279 {
280 let cache = self.cache.read().await;
281 if !cache.is_empty() {
282 return Ok(cache.clone());
283 }
284 }
285
286 let policies = self.load_policies_from_db().await?;
287 let mut cache = self.cache.write().await;
288 *cache = policies.clone();
289 Ok(policies)
290 }
291
292 async fn load_policies_from_db(&self) -> GuardResult<Vec<Policy>> {
293 let rows = sqlx::query(
294 "SELECT id, name, description, rules, priority, enabled, created_at, updated_at
295 FROM policies WHERE enabled = 1 ORDER BY priority DESC, updated_at DESC",
296 )
297 .fetch_all(&self.pool)
298 .await?;
299
300 rows.iter().map(row_to_policy).collect()
301 }
302
303 async fn invalidate_cache(&self) {
304 self.cache.write().await.clear();
305 }
306}
307
308impl Condition {
309 fn matches(&self, ctx: &EvalContext) -> bool {
310 match self {
311 Self::TaskMatches(value) => ctx.task == *value,
312 Self::RoleIs(value) => ctx.role == *value,
313 Self::ScopeContains(value) => ctx.scopes.iter().any(|scope| scope == value),
314 Self::RiskAbove(value) => ctx.risk_score > *value,
315 Self::ResourceIs(value) => ctx.resource == *value,
316 Self::WorkspaceIs(value) => ctx.workspace_id == *value,
317 Self::And(left, right) => left.matches(ctx) && right.matches(ctx),
318 Self::Or(left, right) => left.matches(ctx) || right.matches(ctx),
319 Self::Not(inner) => !inner.matches(ctx),
320 }
321 }
322}
323
324fn parse_policy_source(source: &str, fallback_name: &str) -> GuardResult<Vec<Policy>> {
325 if source.contains("[[policies]]") {
326 let root: PolicyFileRoot = toml::from_str(source)?;
327 root.policies
328 .into_iter()
329 .enumerate()
330 .map(|(index, policy)| policy_from_file(policy, &format!("{fallback_name}-{index}")))
331 .collect()
332 } else {
333 let policy: PolicyFile = toml::from_str(source)?;
334 Ok(vec![policy_from_file(policy, fallback_name)?])
335 }
336}
337
338fn policy_from_file(policy: PolicyFile, fallback_name: &str) -> GuardResult<Policy> {
339 let now = Utc::now();
340 Ok(Policy {
341 id: Uuid::new_v4(),
342 name: policy.name.unwrap_or_else(|| fallback_name.to_owned()),
343 description: policy.description,
344 rules: policy
345 .rules
346 .into_iter()
347 .map(parse_rule)
348 .collect::<GuardResult<Vec<_>>>()?,
349 priority: policy.priority,
350 enabled: policy.enabled,
351 created_at: now,
352 updated_at: now,
353 })
354}
355
356fn parse_rule(rule: PolicyRuleFile) -> GuardResult<PolicyRule> {
357 match rule.kind.as_str() {
358 "allow_if" => Ok(PolicyRule::AllowIf {
359 condition: parse_condition(rule.condition.as_ref().ok_or_else(|| {
360 GuardError::ConfigError("allow_if requires condition".to_owned())
361 })?)?,
362 }),
363 "deny_if" => Ok(PolicyRule::DenyIf {
364 condition: parse_condition(rule.condition.as_ref().ok_or_else(|| {
365 GuardError::ConfigError("deny_if requires condition".to_owned())
366 })?)?,
367 reason: rule
368 .reason
369 .ok_or_else(|| GuardError::ConfigError("deny_if requires reason".to_owned()))?,
370 }),
371 "mask_field" => Ok(PolicyRule::MaskField {
372 field_pattern: rule.field_pattern.ok_or_else(|| {
373 GuardError::ConfigError("mask_field requires field_pattern".to_owned())
374 })?,
375 mask_type: parse_mask_type(rule.mask_type.as_ref().ok_or_else(|| {
376 GuardError::ConfigError("mask_field requires mask_type".to_owned())
377 })?)?,
378 }),
379 _ => Err(GuardError::ConfigError(format!(
380 "unsupported rule type: {}",
381 rule.kind
382 ))),
383 }
384}
385
386fn parse_condition(value: &toml::Value) -> GuardResult<Condition> {
387 let table = value
388 .as_table()
389 .ok_or_else(|| GuardError::ConfigError("condition must be a table".to_owned()))?;
390
391 if let Some(items) = table.get("and") {
392 return fold_conditions(items, Condition::And);
393 }
394 if let Some(items) = table.get("or") {
395 return fold_conditions(items, Condition::Or);
396 }
397 if let Some(item) = table.get("not") {
398 return Ok(Condition::Not(Box::new(parse_condition(item)?)));
399 }
400 if let Some(value) = table.get("task_matches").and_then(toml::Value::as_str) {
401 return Ok(Condition::TaskMatches(value.to_owned()));
402 }
403 if let Some(value) = table.get("role_is").and_then(toml::Value::as_str) {
404 return Ok(Condition::RoleIs(value.to_owned()));
405 }
406 if let Some(value) = table.get("scope_contains").and_then(toml::Value::as_str) {
407 return Ok(Condition::ScopeContains(value.to_owned()));
408 }
409 if let Some(value) = table.get("risk_above").and_then(toml::Value::as_float) {
410 return Ok(Condition::RiskAbove(value));
411 }
412 if let Some(value) = table.get("resource_is").and_then(toml::Value::as_str) {
413 return Ok(Condition::ResourceIs(value.to_owned()));
414 }
415 if let Some(value) = table.get("workspace_is").and_then(toml::Value::as_str) {
416 return Ok(Condition::WorkspaceIs(Uuid::parse_str(value)?));
417 }
418
419 Err(GuardError::ConfigError(
420 "unsupported condition shape".to_owned(),
421 ))
422}
423
424fn fold_conditions(
425 value: &toml::Value,
426 combine: fn(Box<Condition>, Box<Condition>) -> Condition,
427) -> GuardResult<Condition> {
428 let items = value
429 .as_array()
430 .ok_or_else(|| GuardError::ConfigError("logical conditions must be arrays".to_owned()))?;
431 let mut iter = items.iter().map(parse_condition);
432 let first = iter.next().transpose()?.ok_or_else(|| {
433 GuardError::ConfigError("logical conditions must not be empty".to_owned())
434 })?;
435
436 iter.try_fold(first, |left, right| {
437 Ok(combine(Box::new(left), Box::new(right?)))
438 })
439}
440
441fn parse_mask_type(value: &toml::Value) -> GuardResult<MaskType> {
442 if let Some(kind) = value.as_str() {
443 return match kind {
444 "redact" => Ok(MaskType::Redact),
445 "hash_blake3" | "hash" => Ok(MaskType::HashBlake3),
446 "email_mask" => Ok(MaskType::EmailMask),
447 _ => Err(GuardError::ConfigError(format!(
448 "unsupported mask type: {kind}"
449 ))),
450 };
451 }
452
453 let table = value
454 .as_table()
455 .ok_or_else(|| GuardError::ConfigError("mask_type must be a string or table".to_owned()))?;
456 if let Some(truncate) = table.get("truncate") {
457 if let Some(n) = truncate.get("n").and_then(toml::Value::as_integer) {
458 return Ok(MaskType::Truncate { n: n as usize });
459 }
460 }
461 if let Some(json_mask) = table.get("json_field_mask") {
462 if let Some(pattern) = json_mask.get("pattern").and_then(toml::Value::as_str) {
463 return Ok(MaskType::JsonFieldMask {
464 pattern: pattern.to_owned(),
465 });
466 }
467 }
468 Err(GuardError::ConfigError(
469 "unsupported mask type table".to_owned(),
470 ))
471}
472
473fn row_to_policy(row: &sqlx::sqlite::SqliteRow) -> GuardResult<Policy> {
474 Ok(Policy {
475 id: Uuid::parse_str(&row.try_get::<String, _>("id")?)?,
476 name: row.try_get("name")?,
477 description: row.try_get("description")?,
478 rules: serde_json::from_str(&row.try_get::<String, _>("rules")?)?,
479 priority: row.try_get("priority")?,
480 enabled: row.try_get::<i64, _>("enabled")? != 0,
481 created_at: from_ms(row.try_get("created_at")?)?,
482 updated_at: from_ms(row.try_get("updated_at")?)?,
483 })
484}
485
486fn from_ms(value: i64) -> GuardResult<DateTime<Utc>> {
487 DateTime::from_timestamp_millis(value)
488 .ok_or_else(|| GuardError::ConfigError(format!("invalid timestamp millis: {value}")))
489}
490
/// Serde default for `PolicyFile::enabled`: a policy is active unless its file
/// sets `enabled = false`.
fn default_enabled() -> bool {
    bool::default() == false
}