Skip to main content

claw_guard/policy/
mod.rs

1use std::path::Path;
2use std::sync::Arc;
3
4use chrono::{DateTime, Local, Timelike, Utc};
5use glob::glob;
6use serde::{Deserialize, Serialize};
7use sqlx::{Row, SqlitePool};
8use tokio::sync::RwLock;
9use uuid::Uuid;
10
11use crate::config::GuardConfig;
12use crate::error::{GuardError, GuardResult};
13use crate::masking::MaskType;
14use crate::types::PolicyDecision;
15
16/// A policy rule evaluated in order within a policy.
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18pub enum PolicyRule {
19    /// Allows the request if the condition matches.
20    AllowIf { condition: Condition },
21    /// Denies the request if the condition matches.
22    DenyIf {
23        condition: Condition,
24        reason: String,
25    },
26    /// Masks a matching field using the configured mask type.
27    MaskField {
28        field_pattern: String,
29        mask_type: MaskType,
30    },
31}
32
33/// Condition expression supported by the policy engine.
34#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
35pub enum Condition {
36    /// Matches an exact task name.
37    TaskMatches(String),
38    /// Matches an exact role name.
39    RoleIs(String),
40    /// Matches when a scope is present.
41    ScopeContains(String),
42    /// Matches when the risk score is above the threshold.
43    RiskAbove(f64),
44    /// Matches an exact resource name.
45    ResourceIs(String),
46    /// Matches an exact workspace id.
47    WorkspaceIs(Uuid),
48    /// Logical AND.
49    And(Box<Condition>, Box<Condition>),
50    /// Logical OR.
51    Or(Box<Condition>, Box<Condition>),
52    /// Logical NOT.
53    Not(Box<Condition>),
54}
55
56/// Context used for policy evaluation.
57#[derive(Debug, Clone, PartialEq)]
58pub struct EvalContext {
59    /// Agent identifier.
60    pub agent_id: Uuid,
61    /// Workspace identifier.
62    pub workspace_id: Uuid,
63    /// Role name.
64    pub role: String,
65    /// Granted scopes.
66    pub scopes: Vec<String>,
67    /// Task name.
68    pub task: String,
69    /// Resource name.
70    pub resource: String,
71    /// Action name.
72    pub action: String,
73    /// Risk score in the range `[0.0, 1.0]`.
74    pub risk_score: f64,
75}
76
77/// Persisted policy definition.
78#[derive(Debug, Clone, PartialEq)]
79pub struct Policy {
80    /// Policy identifier.
81    pub id: Uuid,
82    /// Unique policy name.
83    pub name: String,
84    /// Optional description.
85    pub description: Option<String>,
86    /// Ordered rules.
87    pub rules: Vec<PolicyRule>,
88    /// Higher values are evaluated first.
89    pub priority: i32,
90    /// Whether the policy is enabled.
91    pub enabled: bool,
92    /// Creation time.
93    pub created_at: DateTime<Utc>,
94    /// Last update time.
95    pub updated_at: DateTime<Utc>,
96}
97
98#[derive(Debug, Deserialize)]
99struct PolicyFileRoot {
100    #[serde(default)]
101    policies: Vec<PolicyFile>,
102}
103
104#[derive(Debug, Deserialize)]
105struct PolicyFile {
106    name: Option<String>,
107    description: Option<String>,
108    #[serde(default)]
109    priority: i32,
110    #[serde(default = "default_enabled")]
111    enabled: bool,
112    #[serde(default)]
113    rules: Vec<PolicyRuleFile>,
114}
115
116#[derive(Debug, Deserialize)]
117struct PolicyRuleFile {
118    #[serde(rename = "type")]
119    kind: String,
120    reason: Option<String>,
121    condition: Option<toml::Value>,
122    field_pattern: Option<String>,
123    mask_type: Option<toml::Value>,
124}
125
126/// Policy engine with a read-through cache backed by SQLite.
127#[derive(Clone)]
128pub struct PolicyEngine {
129    pool: SqlitePool,
130    config: Arc<GuardConfig>,
131    cache: Arc<RwLock<Vec<Policy>>>,
132}
133
134impl PolicyEngine {
135    /// Creates a new policy engine.
136    pub fn new(pool: SqlitePool, config: Arc<GuardConfig>) -> Self {
137        Self {
138            pool,
139            config,
140            cache: Arc::new(RwLock::new(Vec::new())),
141        }
142    }
143
144    /// Evaluates enabled policies in descending priority order.
145    pub async fn evaluate(&self, ctx: &EvalContext) -> GuardResult<PolicyDecision> {
146        let policies = self.cached_policies().await?;
147        for policy in policies {
148            for rule in &policy.rules {
149                match rule {
150                    PolicyRule::AllowIf { condition } if condition.matches(ctx) => {
151                        return Ok(PolicyDecision::Allow);
152                    }
153                    PolicyRule::DenyIf { condition, reason } if condition.matches(ctx) => {
154                        return Ok(PolicyDecision::Deny {
155                            reason: reason.clone(),
156                        });
157                    }
158                    PolicyRule::MaskField { field_pattern, .. } => {
159                        return Ok(PolicyDecision::Mask {
160                            fields: vec![field_pattern.clone()],
161                        });
162                    }
163                    _ => {}
164                }
165            }
166        }
167
168        Ok(PolicyDecision::Deny {
169            reason: "no matching policy rule".to_owned(),
170        })
171    }
172
173    /// Computes a risk score for the provided context.
174    pub fn compute_risk(&self, ctx: &EvalContext) -> f64 {
175        let mut score: f64 = 0.0;
176        let action = ctx.action.to_ascii_lowercase();
177        if action.contains("write") || action.contains("update") {
178            score += 0.3;
179        }
180        if action.contains("delete") {
181            score += 0.5;
182        }
183        if self
184            .config
185            .sensitive_resources
186            .iter()
187            .any(|resource| resource == &ctx.resource)
188        {
189            score += 0.3;
190        }
191        let hour = Local::now().hour() as u8;
192        if hour < self.config.business_hours_start_hour
193            || hour >= self.config.business_hours_end_hour
194        {
195            score += 0.1;
196        }
197        score.clamp(0.0, 1.0)
198    }
199
200    /// Loads all TOML policy files from a directory and upserts them into SQLite.
201    pub async fn load_from_dir(&self, dir: &Path) -> GuardResult<usize> {
202        if !dir.exists() {
203            return Ok(0);
204        }
205
206        let pattern = dir.join("*.toml").to_string_lossy().to_string();
207        let mut loaded = 0_usize;
208        for entry in glob(&pattern).map_err(|error| GuardError::ConfigError(error.to_string()))? {
209            let path = entry.map_err(|error| GuardError::ConfigError(error.to_string()))?;
210            let source = std::fs::read_to_string(&path)?;
211            self.add_policy_from_toml(
212                &source,
213                path.file_stem()
214                    .and_then(|value| value.to_str())
215                    .unwrap_or("policy"),
216            )
217            .await?;
218            loaded += 1;
219        }
220        Ok(loaded)
221    }
222
223    /// Parses one TOML source string and upserts the contained policies.
224    pub async fn add_policy_from_toml(
225        &self,
226        source: &str,
227        fallback_name: &str,
228    ) -> GuardResult<Policy> {
229        let policies = parse_policy_source(source, fallback_name)?;
230        let mut last = None;
231        for policy in policies {
232            self.upsert_policy(&policy).await?;
233            last = Some(policy);
234        }
235        last.ok_or_else(|| GuardError::ConfigError("policy file contained no policies".to_owned()))
236    }
237
238    /// Lists persisted policies in evaluation order.
239    pub async fn list_policies(&self) -> GuardResult<Vec<Policy>> {
240        self.load_policies_from_db().await
241    }
242
243    /// Removes a policy by id.
244    pub async fn remove_policy(&self, policy_id: Uuid) -> GuardResult<()> {
245        sqlx::query("DELETE FROM policies WHERE id = ?1")
246            .bind(policy_id.to_string())
247            .execute(&self.pool)
248            .await?;
249        self.invalidate_cache().await;
250        Ok(())
251    }
252
253    async fn upsert_policy(&self, policy: &Policy) -> GuardResult<()> {
254        sqlx::query(
255            "INSERT INTO policies (id, name, description, rules, priority, enabled, created_at, updated_at)
256             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
257             ON CONFLICT(name) DO UPDATE SET
258                description = excluded.description,
259                rules = excluded.rules,
260                priority = excluded.priority,
261                enabled = excluded.enabled,
262                updated_at = excluded.updated_at",
263        )
264        .bind(policy.id.to_string())
265        .bind(&policy.name)
266        .bind(&policy.description)
267        .bind(serde_json::to_string(&policy.rules)?)
268        .bind(policy.priority)
269        .bind(if policy.enabled { 1_i64 } else { 0_i64 })
270        .bind(policy.created_at.timestamp_millis())
271        .bind(policy.updated_at.timestamp_millis())
272        .execute(&self.pool)
273        .await?;
274        self.invalidate_cache().await;
275        Ok(())
276    }
277
278    async fn cached_policies(&self) -> GuardResult<Vec<Policy>> {
279        {
280            let cache = self.cache.read().await;
281            if !cache.is_empty() {
282                return Ok(cache.clone());
283            }
284        }
285
286        let policies = self.load_policies_from_db().await?;
287        let mut cache = self.cache.write().await;
288        *cache = policies.clone();
289        Ok(policies)
290    }
291
292    async fn load_policies_from_db(&self) -> GuardResult<Vec<Policy>> {
293        let rows = sqlx::query(
294            "SELECT id, name, description, rules, priority, enabled, created_at, updated_at
295             FROM policies WHERE enabled = 1 ORDER BY priority DESC, updated_at DESC",
296        )
297        .fetch_all(&self.pool)
298        .await?;
299
300        rows.iter().map(row_to_policy).collect()
301    }
302
303    async fn invalidate_cache(&self) {
304        self.cache.write().await.clear();
305    }
306}
307
308impl Condition {
309    fn matches(&self, ctx: &EvalContext) -> bool {
310        match self {
311            Self::TaskMatches(value) => ctx.task == *value,
312            Self::RoleIs(value) => ctx.role == *value,
313            Self::ScopeContains(value) => ctx.scopes.iter().any(|scope| scope == value),
314            Self::RiskAbove(value) => ctx.risk_score > *value,
315            Self::ResourceIs(value) => ctx.resource == *value,
316            Self::WorkspaceIs(value) => ctx.workspace_id == *value,
317            Self::And(left, right) => left.matches(ctx) && right.matches(ctx),
318            Self::Or(left, right) => left.matches(ctx) || right.matches(ctx),
319            Self::Not(inner) => !inner.matches(ctx),
320        }
321    }
322}
323
324fn parse_policy_source(source: &str, fallback_name: &str) -> GuardResult<Vec<Policy>> {
325    if source.contains("[[policies]]") {
326        let root: PolicyFileRoot = toml::from_str(source)?;
327        root.policies
328            .into_iter()
329            .enumerate()
330            .map(|(index, policy)| policy_from_file(policy, &format!("{fallback_name}-{index}")))
331            .collect()
332    } else {
333        let policy: PolicyFile = toml::from_str(source)?;
334        Ok(vec![policy_from_file(policy, fallback_name)?])
335    }
336}
337
338fn policy_from_file(policy: PolicyFile, fallback_name: &str) -> GuardResult<Policy> {
339    let now = Utc::now();
340    Ok(Policy {
341        id: Uuid::new_v4(),
342        name: policy.name.unwrap_or_else(|| fallback_name.to_owned()),
343        description: policy.description,
344        rules: policy
345            .rules
346            .into_iter()
347            .map(parse_rule)
348            .collect::<GuardResult<Vec<_>>>()?,
349        priority: policy.priority,
350        enabled: policy.enabled,
351        created_at: now,
352        updated_at: now,
353    })
354}
355
356fn parse_rule(rule: PolicyRuleFile) -> GuardResult<PolicyRule> {
357    match rule.kind.as_str() {
358        "allow_if" => Ok(PolicyRule::AllowIf {
359            condition: parse_condition(rule.condition.as_ref().ok_or_else(|| {
360                GuardError::ConfigError("allow_if requires condition".to_owned())
361            })?)?,
362        }),
363        "deny_if" => Ok(PolicyRule::DenyIf {
364            condition: parse_condition(rule.condition.as_ref().ok_or_else(|| {
365                GuardError::ConfigError("deny_if requires condition".to_owned())
366            })?)?,
367            reason: rule
368                .reason
369                .ok_or_else(|| GuardError::ConfigError("deny_if requires reason".to_owned()))?,
370        }),
371        "mask_field" => Ok(PolicyRule::MaskField {
372            field_pattern: rule.field_pattern.ok_or_else(|| {
373                GuardError::ConfigError("mask_field requires field_pattern".to_owned())
374            })?,
375            mask_type: parse_mask_type(rule.mask_type.as_ref().ok_or_else(|| {
376                GuardError::ConfigError("mask_field requires mask_type".to_owned())
377            })?)?,
378        }),
379        _ => Err(GuardError::ConfigError(format!(
380            "unsupported rule type: {}",
381            rule.kind
382        ))),
383    }
384}
385
386fn parse_condition(value: &toml::Value) -> GuardResult<Condition> {
387    let table = value
388        .as_table()
389        .ok_or_else(|| GuardError::ConfigError("condition must be a table".to_owned()))?;
390
391    if let Some(items) = table.get("and") {
392        return fold_conditions(items, Condition::And);
393    }
394    if let Some(items) = table.get("or") {
395        return fold_conditions(items, Condition::Or);
396    }
397    if let Some(item) = table.get("not") {
398        return Ok(Condition::Not(Box::new(parse_condition(item)?)));
399    }
400    if let Some(value) = table.get("task_matches").and_then(toml::Value::as_str) {
401        return Ok(Condition::TaskMatches(value.to_owned()));
402    }
403    if let Some(value) = table.get("role_is").and_then(toml::Value::as_str) {
404        return Ok(Condition::RoleIs(value.to_owned()));
405    }
406    if let Some(value) = table.get("scope_contains").and_then(toml::Value::as_str) {
407        return Ok(Condition::ScopeContains(value.to_owned()));
408    }
409    if let Some(value) = table.get("risk_above").and_then(toml::Value::as_float) {
410        return Ok(Condition::RiskAbove(value));
411    }
412    if let Some(value) = table.get("resource_is").and_then(toml::Value::as_str) {
413        return Ok(Condition::ResourceIs(value.to_owned()));
414    }
415    if let Some(value) = table.get("workspace_is").and_then(toml::Value::as_str) {
416        return Ok(Condition::WorkspaceIs(Uuid::parse_str(value)?));
417    }
418
419    Err(GuardError::ConfigError(
420        "unsupported condition shape".to_owned(),
421    ))
422}
423
424fn fold_conditions(
425    value: &toml::Value,
426    combine: fn(Box<Condition>, Box<Condition>) -> Condition,
427) -> GuardResult<Condition> {
428    let items = value
429        .as_array()
430        .ok_or_else(|| GuardError::ConfigError("logical conditions must be arrays".to_owned()))?;
431    let mut iter = items.iter().map(parse_condition);
432    let first = iter.next().transpose()?.ok_or_else(|| {
433        GuardError::ConfigError("logical conditions must not be empty".to_owned())
434    })?;
435
436    iter.try_fold(first, |left, right| {
437        Ok(combine(Box::new(left), Box::new(right?)))
438    })
439}
440
441fn parse_mask_type(value: &toml::Value) -> GuardResult<MaskType> {
442    if let Some(kind) = value.as_str() {
443        return match kind {
444            "redact" => Ok(MaskType::Redact),
445            "hash_blake3" | "hash" => Ok(MaskType::HashBlake3),
446            "email_mask" => Ok(MaskType::EmailMask),
447            _ => Err(GuardError::ConfigError(format!(
448                "unsupported mask type: {kind}"
449            ))),
450        };
451    }
452
453    let table = value
454        .as_table()
455        .ok_or_else(|| GuardError::ConfigError("mask_type must be a string or table".to_owned()))?;
456    if let Some(truncate) = table.get("truncate") {
457        if let Some(n) = truncate.get("n").and_then(toml::Value::as_integer) {
458            return Ok(MaskType::Truncate { n: n as usize });
459        }
460    }
461    if let Some(json_mask) = table.get("json_field_mask") {
462        if let Some(pattern) = json_mask.get("pattern").and_then(toml::Value::as_str) {
463            return Ok(MaskType::JsonFieldMask {
464                pattern: pattern.to_owned(),
465            });
466        }
467    }
468    Err(GuardError::ConfigError(
469        "unsupported mask type table".to_owned(),
470    ))
471}
472
473fn row_to_policy(row: &sqlx::sqlite::SqliteRow) -> GuardResult<Policy> {
474    Ok(Policy {
475        id: Uuid::parse_str(&row.try_get::<String, _>("id")?)?,
476        name: row.try_get("name")?,
477        description: row.try_get("description")?,
478        rules: serde_json::from_str(&row.try_get::<String, _>("rules")?)?,
479        priority: row.try_get("priority")?,
480        enabled: row.try_get::<i64, _>("enabled")? != 0,
481        created_at: from_ms(row.try_get("created_at")?)?,
482        updated_at: from_ms(row.try_get("updated_at")?)?,
483    })
484}
485
486fn from_ms(value: i64) -> GuardResult<DateTime<Utc>> {
487    DateTime::from_timestamp_millis(value)
488        .ok_or_else(|| GuardError::ConfigError(format!("invalid timestamp millis: {value}")))
489}
490
/// Serde default for a policy's `enabled` flag: policies are active unless
/// the file explicitly opts out.
const fn default_enabled() -> bool {
    true
}