1use std::path::Path;
2use std::sync::Arc;
3
4use chrono::Utc;
5use sqlx::{sqlite::SqlitePoolOptions, SqlitePool};
6use uuid::Uuid;
7
8use crate::audit::{write_direct, AuditEntry, AuditFilter, AuditReader, AuditWriter};
9use crate::config::GuardConfig;
10use crate::error::GuardResult;
11use crate::keys::ApiKeyManager;
12use crate::masking::MaskingEngine;
13use crate::policy::{EvalContext, PolicyEngine};
14use crate::session::SessionManager;
15use crate::types::{AccessResult, GuardSession, PolicyDecision};
16
17#[derive(Clone)]
19pub struct Guard {
20 pool: SqlitePool,
21 keys: ApiKeyManager,
22 sessions: SessionManager,
23 policy: PolicyEngine,
24 masking: MaskingEngine,
25 audit: AuditWriter,
26 audit_reader: AuditReader,
27 config: Arc<GuardConfig>,
28}
29
30impl Guard {
31 pub async fn new(config: GuardConfig) -> GuardResult<Self> {
33 let config = Arc::new(config);
34 let pool = SqlitePoolOptions::new()
35 .max_connections(5)
36 .connect(&config.sqlite_connection_string())
37 .await?;
38 sqlx::migrate!("./migrations").run(&pool).await?;
39
40 let policy = PolicyEngine::new(pool.clone(), config.clone());
41 if let Some(policy_dir) = &config.policy_dir {
42 policy.load_from_dir(policy_dir).await?;
43 }
44
45 let audit = AuditWriter::new(
46 pool.clone(),
47 tokio::time::Duration::from_millis(config.audit_flush_interval_ms),
48 config.audit_batch_size,
49 );
50
51 Ok(Self {
52 keys: ApiKeyManager::new(pool.clone()),
53 sessions: SessionManager::new(pool.clone(), config.clone()),
54 policy,
55 masking: MaskingEngine::new(),
56 audit,
57 audit_reader: AuditReader::new(pool.clone()),
58 pool,
59 config,
60 })
61 }
62
63 pub fn keys(&self) -> &ApiKeyManager {
65 &self.keys
66 }
67
68 pub fn sessions(&self) -> &SessionManager {
70 &self.sessions
71 }
72
73 pub fn policy_engine(&self) -> &PolicyEngine {
75 &self.policy
76 }
77
78 pub fn masking_engine(&self) -> &MaskingEngine {
80 &self.masking
81 }
82
83 pub fn config(&self) -> &GuardConfig {
85 self.config.as_ref()
86 }
87
88 pub fn pool(&self) -> &SqlitePool {
90 &self.pool
91 }
92
93 pub async fn check_access(
95 &self,
96 session: &GuardSession,
97 action: &str,
98 resource: &str,
99 ) -> GuardResult<AccessResult> {
100 self.check_access_with_task(session, action, resource, action)
101 .await
102 }
103
104 pub async fn check_access_with_task(
106 &self,
107 session: &GuardSession,
108 action: &str,
109 resource: &str,
110 task: &str,
111 ) -> GuardResult<AccessResult> {
112 let base_entry = AuditEntry {
113 id: Uuid::new_v4(),
114 session_id: Some(session.id),
115 workspace_id: session.workspace_id,
116 agent_id: Some(session.agent_id),
117 action: action.to_owned(),
118 resource: resource.to_owned(),
119 resource_id: None,
120 decision: String::new(),
121 reason: None,
122 risk_score: 0.0,
123 metadata: serde_json::json!({}),
124 ts: Utc::now(),
125 };
126
127 match self.sessions.assert_session_active(session).await {
128 Ok(()) => {
129 let mut ctx = EvalContext {
130 agent_id: session.agent_id,
131 workspace_id: session.workspace_id,
132 role: session.role.clone(),
133 scopes: session.scopes.clone(),
134 task: task.to_owned(),
135 resource: resource.to_owned(),
136 action: action.to_owned(),
137 risk_score: 0.0,
138 };
139 ctx.risk_score = self.policy.compute_risk(&ctx);
140 let decision = self.policy.evaluate(&ctx).await?;
141 self.write_audit(base_entry, &decision, ctx.risk_score)
142 .await?;
143 Ok(decision)
144 }
145 Err(error) => {
146 let decision = PolicyDecision::Deny {
147 reason: error.to_string(),
148 };
149 self.write_audit(base_entry, &decision, 1.0).await?;
150 Err(error)
151 }
152 }
153 }
154
155 pub fn check_tool_permission(&self, session: &GuardSession, tool_name: &str) -> bool {
157 session
158 .scopes
159 .iter()
160 .any(|scope| scope == "tool:*" || scope == &format!("tool:{tool_name}"))
161 }
162
163 pub async fn query_audit(&self, filter: AuditFilter) -> GuardResult<Vec<AuditEntry>> {
165 self.audit_reader.query(filter).await
166 }
167
168 async fn write_audit(
169 &self,
170 mut entry: AuditEntry,
171 decision: &PolicyDecision,
172 risk_score: f64,
173 ) -> GuardResult<()> {
174 match decision {
175 PolicyDecision::Allow => {
176 entry.decision = "Allow".to_owned();
177 }
178 PolicyDecision::Deny { reason } => {
179 entry.decision = "Deny".to_owned();
180 entry.reason = Some(reason.clone());
181 }
182 PolicyDecision::Mask { fields } => {
183 entry.decision = "Mask".to_owned();
184 entry.metadata = serde_json::json!({ "fields": fields });
185 }
186 }
187 entry.risk_score = risk_score;
188
189 match self.audit.write(entry.clone()).await {
190 Ok(()) => Ok(()),
191 Err(_) => write_direct(&self.pool, &entry).await,
192 }
193 }
194}
195
/// No-op whose only purpose is to reference `Path`.
///
/// Presumably this keeps the `std::path::Path` import from being reported as unused.
/// It ignores its argument and does nothing.
#[allow(dead_code)] // intentionally never called; exists only to reference `Path`
fn _keep_path_used(_: &Path) {}