Skip to main content

claw_guard/
guard.rs

1use std::path::Path;
2use std::sync::Arc;
3
4use chrono::Utc;
5use sqlx::{sqlite::SqlitePoolOptions, SqlitePool};
6use uuid::Uuid;
7
8use crate::audit::{write_direct, AuditEntry, AuditFilter, AuditReader, AuditWriter};
9use crate::config::GuardConfig;
10use crate::error::GuardResult;
11use crate::keys::ApiKeyManager;
12use crate::masking::MaskingEngine;
13use crate::policy::{EvalContext, PolicyEngine};
14use crate::session::SessionManager;
15use crate::types::{AccessResult, GuardSession, PolicyDecision};
16
17/// Main public entry point for ClawDB security checks.
18#[derive(Clone)]
19pub struct Guard {
20    pool: SqlitePool,
21    keys: ApiKeyManager,
22    sessions: SessionManager,
23    policy: PolicyEngine,
24    masking: MaskingEngine,
25    audit: AuditWriter,
26    audit_reader: AuditReader,
27    config: Arc<GuardConfig>,
28}
29
30impl Guard {
31    /// Opens the SQLite pool, applies migrations, and initializes guard services.
32    pub async fn new(config: GuardConfig) -> GuardResult<Self> {
33        let config = Arc::new(config);
34        let pool = SqlitePoolOptions::new()
35            .max_connections(5)
36            .connect(&config.sqlite_connection_string())
37            .await?;
38        sqlx::migrate!("./migrations").run(&pool).await?;
39
40        let policy = PolicyEngine::new(pool.clone(), config.clone());
41        if let Some(policy_dir) = &config.policy_dir {
42            policy.load_from_dir(policy_dir).await?;
43        }
44
45        let audit = AuditWriter::new(
46            pool.clone(),
47            tokio::time::Duration::from_millis(config.audit_flush_interval_ms),
48            config.audit_batch_size,
49        );
50
51        Ok(Self {
52            keys: ApiKeyManager::new(pool.clone()),
53            sessions: SessionManager::new(pool.clone(), config.clone()),
54            policy,
55            masking: MaskingEngine::new(),
56            audit,
57            audit_reader: AuditReader::new(pool.clone()),
58            pool,
59            config,
60        })
61    }
62
63    /// Returns the shared API key manager.
64    pub fn keys(&self) -> &ApiKeyManager {
65        &self.keys
66    }
67
68    /// Returns the shared session manager.
69    pub fn sessions(&self) -> &SessionManager {
70        &self.sessions
71    }
72
73    /// Returns the policy engine.
74    pub fn policy_engine(&self) -> &PolicyEngine {
75        &self.policy
76    }
77
78    /// Returns the configured masking engine.
79    pub fn masking_engine(&self) -> &MaskingEngine {
80        &self.masking
81    }
82
83    /// Returns the guard configuration.
84    pub fn config(&self) -> &GuardConfig {
85        self.config.as_ref()
86    }
87
88    /// Returns the underlying SQLite pool.
89    pub fn pool(&self) -> &SqlitePool {
90        &self.pool
91    }
92
93    /// Evaluates a session, action, and resource triplet.
94    pub async fn check_access(
95        &self,
96        session: &GuardSession,
97        action: &str,
98        resource: &str,
99    ) -> GuardResult<AccessResult> {
100        self.check_access_with_task(session, action, resource, action)
101            .await
102    }
103
104    /// Evaluates access with an explicit task name.
105    pub async fn check_access_with_task(
106        &self,
107        session: &GuardSession,
108        action: &str,
109        resource: &str,
110        task: &str,
111    ) -> GuardResult<AccessResult> {
112        let base_entry = AuditEntry {
113            id: Uuid::new_v4(),
114            session_id: Some(session.id),
115            workspace_id: session.workspace_id,
116            agent_id: Some(session.agent_id),
117            action: action.to_owned(),
118            resource: resource.to_owned(),
119            resource_id: None,
120            decision: String::new(),
121            reason: None,
122            risk_score: 0.0,
123            metadata: serde_json::json!({}),
124            ts: Utc::now(),
125        };
126
127        match self.sessions.assert_session_active(session).await {
128            Ok(()) => {
129                let mut ctx = EvalContext {
130                    agent_id: session.agent_id,
131                    workspace_id: session.workspace_id,
132                    role: session.role.clone(),
133                    scopes: session.scopes.clone(),
134                    task: task.to_owned(),
135                    resource: resource.to_owned(),
136                    action: action.to_owned(),
137                    risk_score: 0.0,
138                };
139                ctx.risk_score = self.policy.compute_risk(&ctx);
140                let decision = self.policy.evaluate(&ctx).await?;
141                self.write_audit(base_entry, &decision, ctx.risk_score)
142                    .await?;
143                Ok(decision)
144            }
145            Err(error) => {
146                let decision = PolicyDecision::Deny {
147                    reason: error.to_string(),
148                };
149                self.write_audit(base_entry, &decision, 1.0).await?;
150                Err(error)
151            }
152        }
153    }
154
155    /// Checks whether a session grants permission to use a tool.
156    pub fn check_tool_permission(&self, session: &GuardSession, tool_name: &str) -> bool {
157        session
158            .scopes
159            .iter()
160            .any(|scope| scope == "tool:*" || scope == &format!("tool:{tool_name}"))
161    }
162
163    /// Queries persisted audit records.
164    pub async fn query_audit(&self, filter: AuditFilter) -> GuardResult<Vec<AuditEntry>> {
165        self.audit_reader.query(filter).await
166    }
167
168    async fn write_audit(
169        &self,
170        mut entry: AuditEntry,
171        decision: &PolicyDecision,
172        risk_score: f64,
173    ) -> GuardResult<()> {
174        match decision {
175            PolicyDecision::Allow => {
176                entry.decision = "Allow".to_owned();
177            }
178            PolicyDecision::Deny { reason } => {
179                entry.decision = "Deny".to_owned();
180                entry.reason = Some(reason.clone());
181            }
182            PolicyDecision::Mask { fields } => {
183                entry.decision = "Mask".to_owned();
184                entry.metadata = serde_json::json!({ "fields": fields });
185            }
186        }
187        entry.risk_score = risk_score;
188
189        match self.audit.write(entry.clone()).await {
190            Ok(()) => Ok(()),
191            Err(_) => write_direct(&self.pool, &entry).await,
192        }
193    }
194}
195
/// No-op that references `Path` so the file's `std::path::Path` import is used.
///
/// Nothing in this file calls it. Deleting it together with that import is the
/// cleaner fix once nothing else here needs `Path`.
// Explicitly allowed because the function is intentionally never called.
#[allow(dead_code)]
fn _keep_path_used(_unused_path: &Path) {}