1use crate::*;
4use moka::future::Cache;
5use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
6use sqlx::Row;
7use std::collections::{HashMap, HashSet};
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::Duration;
11
12pub struct AuthzEngine {
14 pool: SqlitePool,
15 cache: Arc<Cache<String, bool>>,
16 namespace_configs: Arc<HashMap<String, NamespaceConfig>>,
17 bloom_filter: Arc<AuthzBloomFilter>,
19 bloom_stats: Arc<BloomStatsTracker>,
21}
22
/// Atomic counters describing bloom-filter outcomes during batch checks.
///
/// Every counter is an independent `AtomicU64`, so a shared tracker can be
/// updated through `&self` from concurrent tasks.
pub struct BloomStatsTracker {
    /// Lookups the filter rejected outright (no database work done).
    definite_negatives: AtomicU64,
    /// Lookups the filter let through to the database.
    potential_positives: AtomicU64,
    /// Let-through lookups that ended up allowed.
    true_positives: AtomicU64,
    /// Let-through lookups that ended up denied.
    false_positives: AtomicU64,
}
30
31impl Default for BloomStatsTracker {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl BloomStatsTracker {
38 pub fn new() -> Self {
39 Self {
40 definite_negatives: AtomicU64::new(0),
41 potential_positives: AtomicU64::new(0),
42 true_positives: AtomicU64::new(0),
43 false_positives: AtomicU64::new(0),
44 }
45 }
46
47 pub fn record_definite_negative(&self) {
48 self.definite_negatives.fetch_add(1, Ordering::Relaxed);
49 }
50
51 pub fn record_potential_positive(&self) {
52 self.potential_positives.fetch_add(1, Ordering::Relaxed);
53 }
54
55 pub fn record_true_positive(&self) {
56 self.true_positives.fetch_add(1, Ordering::Relaxed);
57 }
58
59 pub fn record_false_positive(&self) {
60 self.false_positives.fetch_add(1, Ordering::Relaxed);
61 }
62
63 pub fn get_stats(&self) -> BloomStats {
64 BloomStats {
65 definite_negatives: self.definite_negatives.load(Ordering::Relaxed),
66 potential_positives: self.potential_positives.load(Ordering::Relaxed),
67 true_positives: self.true_positives.load(Ordering::Relaxed),
68 false_positives: self.false_positives.load(Ordering::Relaxed),
69 }
70 }
71}
72
73impl AuthzEngine {
74 pub async fn new(database_url: &str) -> Result<Self> {
76 let pool = SqlitePoolOptions::new()
77 .max_connections(20)
78 .acquire_timeout(Duration::from_secs(5))
79 .connect(database_url)
80 .await
81 .map_err(|e| AuthzError::DatabaseError(format!("Failed to connect: {}", e)))?;
82
83 let cache = Cache::builder()
85 .max_capacity(100_000)
86 .time_to_live(Duration::from_secs(3600))
87 .build();
88
89 let mut namespace_configs = HashMap::new();
91 namespace_configs.insert(
92 "document".to_string(),
93 NamespaceConfig::document_namespace(),
94 );
95 namespace_configs.insert("folder".to_string(), NamespaceConfig::folder_namespace());
96
97 let bloom_filter = Arc::new(AuthzBloomFilter::with_config(BloomConfig {
99 expected_items: 1_000_000,
100 false_positive_rate: 0.01,
101 }));
102
103 Ok(Self {
104 pool,
105 cache: Arc::new(cache),
106 namespace_configs: Arc::new(namespace_configs),
107 bloom_filter,
108 bloom_stats: Arc::new(BloomStatsTracker::new()),
109 })
110 }
111
112 pub fn bloom_stats(&self) -> BloomStats {
114 self.bloom_stats.get_stats()
115 }
116
117 pub fn bloom_filter(&self) -> &AuthzBloomFilter {
119 &self.bloom_filter
120 }
121
122 pub async fn write_tuple(&self, tuple: RelationTuple) -> Result<()> {
124 sqlx::query(
125 r#"
126 INSERT OR IGNORE INTO authz_relation_tuples
127 (namespace, object_id, relation, subject_type, subject_id, subject_relation)
128 VALUES (?, ?, ?, ?, ?, ?)
129 "#,
130 )
131 .bind(&tuple.namespace)
132 .bind(&tuple.object_id)
133 .bind(&tuple.relation)
134 .bind(match &tuple.subject {
135 Subject::User(_) => "user",
136 Subject::UserSet { .. } => "userset",
137 })
138 .bind(match &tuple.subject {
139 Subject::User(id) => id.clone(),
140 Subject::UserSet {
141 namespace,
142 object_id,
143 ..
144 } => format!("{}:{}", namespace, object_id),
145 })
146 .bind(match &tuple.subject {
147 Subject::User(_) => None,
148 Subject::UserSet { relation, .. } => Some(relation.clone()),
149 })
150 .execute(&self.pool)
151 .await
152 .map_err(|e| AuthzError::DatabaseError(format!("Failed to write tuple: {}", e)))?;
153
154 self.bloom_filter.add_tuple(&tuple);
156
157 let cache_key = self.cache_key(&tuple.namespace, &tuple.object_id, &tuple.relation);
159 self.cache.invalidate(&cache_key).await;
160
161 Ok(())
162 }
163
164 pub async fn delete_tuple(&self, tuple: RelationTuple) -> Result<()> {
166 sqlx::query(
167 r#"
168 DELETE FROM authz_relation_tuples
169 WHERE namespace = ?
170 AND object_id = ?
171 AND relation = ?
172 AND subject_type = ?
173 AND subject_id = ?
174 "#,
175 )
176 .bind(&tuple.namespace)
177 .bind(&tuple.object_id)
178 .bind(&tuple.relation)
179 .bind(match &tuple.subject {
180 Subject::User(_) => "user",
181 Subject::UserSet { .. } => "userset",
182 })
183 .bind(match &tuple.subject {
184 Subject::User(id) => id.clone(),
185 Subject::UserSet {
186 namespace,
187 object_id,
188 ..
189 } => format!("{}:{}", namespace, object_id),
190 })
191 .execute(&self.pool)
192 .await
193 .map_err(|e| AuthzError::DatabaseError(format!("Failed to delete tuple: {}", e)))?;
194
195 let cache_key = self.cache_key(&tuple.namespace, &tuple.object_id, &tuple.relation);
197 self.cache.invalidate(&cache_key).await;
198
199 Ok(())
200 }
201
202 pub async fn check(&self, request: CheckRequest) -> Result<CheckResponse> {
204 let cache_key = format!(
206 "check:{}:{}:{}:{}",
207 request.namespace, request.object_id, request.relation, request.subject
208 );
209
210 if let Some(allowed) = self.cache.get(&cache_key).await {
212 return Ok(CheckResponse {
213 allowed,
214 cached: true,
215 });
216 }
217
218 let allowed = self
220 .check_recursive(&request, 0, &mut HashSet::new())
221 .await?;
222
223 self.cache.insert(cache_key, allowed).await;
225
226 Ok(CheckResponse {
227 allowed,
228 cached: false,
229 })
230 }
231
232 fn check_recursive<'a>(
234 &'a self,
235 request: &'a CheckRequest,
236 depth: usize,
237 visited: &'a mut HashSet<String>,
238 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<bool>> + Send + 'a>> {
239 Box::pin(async move {
240 if depth > 10 {
242 return Err(AuthzError::CycleDetected);
243 }
244
245 let visit_key = format!(
246 "{}:{}:{}",
247 request.namespace, request.object_id, request.relation
248 );
249 if visited.contains(&visit_key) {
250 return Ok(false); }
252 visited.insert(visit_key);
253
254 let direct = self.check_direct(request).await?;
256 if direct {
257 return Ok(true);
258 }
259
260 if let Some(namespace_config) = self.namespace_configs.get(&request.namespace) {
262 if let Some(relation_config) = namespace_config
263 .relations
264 .iter()
265 .find(|r| r.name == request.relation)
266 {
267 for inherited_relation in &relation_config.inherits_from {
269 let inherited_request = CheckRequest {
270 namespace: request.namespace.clone(),
271 object_id: request.object_id.clone(),
272 relation: inherited_relation.clone(),
273 subject: request.subject.clone(),
274 context: None,
275 };
276
277 if self
278 .check_recursive(&inherited_request, depth + 1, visited)
279 .await?
280 {
281 return Ok(true);
282 }
283 }
284 }
285 }
286
287 if let Subject::User(user_id) = &request.subject {
289 let usersets = self.find_usersets_for_user(user_id).await?;
291
292 for userset in usersets {
293 let userset_request = CheckRequest {
294 namespace: request.namespace.clone(),
295 object_id: request.object_id.clone(),
296 relation: request.relation.clone(),
297 subject: userset,
298 context: None,
299 };
300
301 if self
302 .check_recursive(&userset_request, depth + 1, visited)
303 .await?
304 {
305 return Ok(true);
306 }
307 }
308 }
309
310 Ok(false)
311 })
312 }
313
314 async fn check_direct(&self, request: &CheckRequest) -> Result<bool> {
316 let row = sqlx::query(
317 r#"
318 SELECT COUNT(*) as count FROM authz_relation_tuples
319 WHERE namespace = ?
320 AND object_id = ?
321 AND relation = ?
322 AND subject_type = ?
323 AND subject_id = ?
324 "#,
325 )
326 .bind(&request.namespace)
327 .bind(&request.object_id)
328 .bind(&request.relation)
329 .bind(match &request.subject {
330 Subject::User(_) => "user",
331 Subject::UserSet { .. } => "userset",
332 })
333 .bind(match &request.subject {
334 Subject::User(id) => id.clone(),
335 Subject::UserSet {
336 namespace,
337 object_id,
338 ..
339 } => format!("{}:{}", namespace, object_id),
340 })
341 .fetch_one(&self.pool)
342 .await
343 .map_err(|e| AuthzError::DatabaseError(format!("Failed to check direct: {}", e)))?;
344
345 let count: i64 = row.try_get("count").unwrap_or(0);
346 Ok(count > 0)
347 }
348
349 async fn find_usersets_for_user(&self, user_id: &str) -> Result<Vec<Subject>> {
351 let rows = sqlx::query(
352 r#"
353 SELECT namespace, object_id, relation
354 FROM authz_relation_tuples
355 WHERE subject_type = 'user'
356 AND subject_id = ?
357 "#,
358 )
359 .bind(user_id)
360 .fetch_all(&self.pool)
361 .await
362 .map_err(|e| AuthzError::DatabaseError(format!("Failed to find usersets: {}", e)))?;
363
364 let mut usersets = Vec::new();
365 for row in rows {
366 let namespace: String = row.get("namespace");
367 let object_id: String = row.get("object_id");
368 let relation: String = row.get("relation");
369
370 usersets.push(Subject::UserSet {
371 namespace,
372 object_id,
373 relation,
374 });
375 }
376
377 Ok(usersets)
378 }
379
380 pub async fn expand(&self, request: ExpandRequest) -> Result<ExpandResponse> {
382 let rows = sqlx::query(
383 r#"
384 SELECT subject_type, subject_id, subject_relation
385 FROM authz_relation_tuples
386 WHERE namespace = ?
387 AND object_id = ?
388 AND relation = ?
389 "#,
390 )
391 .bind(&request.namespace)
392 .bind(&request.object_id)
393 .bind(&request.relation)
394 .fetch_all(&self.pool)
395 .await
396 .map_err(|e| AuthzError::DatabaseError(format!("Failed to expand: {}", e)))?;
397
398 let mut subjects = Vec::new();
399 for row in rows {
400 let subject_type: String = row.get("subject_type");
401 let subject_id: String = row.get("subject_id");
402
403 let subject = if subject_type == "user" {
404 Subject::User(subject_id)
405 } else {
406 let subject_relation: Option<String> = row.get("subject_relation");
407 let parts: Vec<&str> = subject_id.split(':').collect();
408 if let (2, Some(relation)) = (parts.len(), subject_relation) {
409 Subject::UserSet {
410 namespace: parts[0].to_string(),
411 object_id: parts[1].to_string(),
412 relation,
413 }
414 } else {
415 continue; }
417 };
418
419 subjects.push(subject);
420 }
421
422 Ok(ExpandResponse { subjects })
423 }
424
425 fn cache_key(&self, namespace: &str, object_id: &str, relation: &str) -> String {
427 format!("{}:{}:{}", namespace, object_id, relation)
428 }
429
430 pub async fn migrate(&self) -> Result<()> {
432 sqlx::query(include_str!("../migrations/001_init.sql"))
433 .execute(&self.pool)
434 .await
435 .map_err(|e| AuthzError::DatabaseError(format!("Migration failed: {}", e)))?;
436
437 Ok(())
438 }
439
440 pub async fn batch_check(&self, requests: &[CheckRequest]) -> Result<Vec<CheckResponse>> {
449 if requests.is_empty() {
450 return Ok(Vec::new());
451 }
452
453 let mut results = vec![None; requests.len()];
454
455 let mut cache_misses = Vec::new();
457 for (idx, request) in requests.iter().enumerate() {
458 let cache_key = format!(
459 "check:{}:{}:{}:{}",
460 request.namespace, request.object_id, request.relation, request.subject
461 );
462
463 if let Some(allowed) = self.cache.get(&cache_key).await {
464 results[idx] = Some(CheckResponse {
465 allowed,
466 cached: true,
467 });
468 } else {
469 cache_misses.push((idx, request, cache_key));
470 }
471 }
472
473 if cache_misses.is_empty() {
474 return Ok(results.into_iter().map(|r| r.unwrap()).collect());
475 }
476
477 let mut bloom_positives = Vec::new();
479 for (idx, request, cache_key) in cache_misses {
480 if self.bloom_filter.might_contain(request) {
481 self.bloom_stats.record_potential_positive();
482 bloom_positives.push((idx, request, cache_key));
483 } else {
484 self.bloom_stats.record_definite_negative();
486 results[idx] = Some(CheckResponse {
487 allowed: false,
488 cached: false,
489 });
490 }
491 }
492
493 if bloom_positives.is_empty() {
494 return Ok(results.into_iter().map(|r| r.unwrap()).collect());
495 }
496
497 let db_results = self.batch_check_direct(&bloom_positives).await?;
499
500 for ((idx, request, cache_key), found) in bloom_positives.into_iter().zip(db_results) {
502 let allowed = if found {
503 self.bloom_stats.record_true_positive();
504 true
505 } else {
506 let recursive_result = self
508 .check_recursive(request, 0, &mut HashSet::new())
509 .await?;
510 if recursive_result {
511 self.bloom_stats.record_true_positive();
512 } else {
513 self.bloom_stats.record_false_positive();
514 }
515 recursive_result
516 };
517
518 self.cache.insert(cache_key, allowed).await;
520
521 results[idx] = Some(CheckResponse {
522 allowed,
523 cached: false,
524 });
525 }
526
527 Ok(results.into_iter().map(|r| r.unwrap()).collect())
528 }
529
530 async fn batch_check_direct(
532 &self,
533 requests: &[(usize, &CheckRequest, String)],
534 ) -> Result<Vec<bool>> {
535 if requests.is_empty() {
536 return Ok(Vec::new());
537 }
538
539 let mut results = Vec::with_capacity(requests.len());
542
543 for (_, request, _) in requests {
544 let subject_type = match &request.subject {
545 Subject::User(_) => "user",
546 Subject::UserSet { .. } => "userset",
547 };
548 let subject_id = match &request.subject {
549 Subject::User(id) => id.clone(),
550 Subject::UserSet {
551 namespace,
552 object_id,
553 ..
554 } => format!("{}:{}", namespace, object_id),
555 };
556
557 let row = sqlx::query(
558 r#"
559 SELECT COUNT(*) as count FROM authz_relation_tuples
560 WHERE namespace = ?
561 AND object_id = ?
562 AND relation = ?
563 AND subject_type = ?
564 AND subject_id = ?
565 "#,
566 )
567 .bind(&request.namespace)
568 .bind(&request.object_id)
569 .bind(&request.relation)
570 .bind(subject_type)
571 .bind(&subject_id)
572 .fetch_one(&self.pool)
573 .await
574 .map_err(|e| AuthzError::DatabaseError(format!("Batch check failed: {}", e)))?;
575
576 let count: i64 = row.try_get("count").unwrap_or(0);
577 results.push(count > 0);
578 }
579
580 Ok(results)
581 }
582
583 pub async fn warm_bloom_filter(&self) -> Result<usize> {
585 let rows = sqlx::query(
586 r#"
587 SELECT namespace, object_id, relation, subject_type, subject_id, subject_relation
588 FROM authz_relation_tuples
589 "#,
590 )
591 .fetch_all(&self.pool)
592 .await
593 .map_err(|e| AuthzError::DatabaseError(format!("Failed to load tuples: {}", e)))?;
594
595 let mut count = 0;
596 for row in rows {
597 let namespace: String = row.get("namespace");
598 let object_id: String = row.get("object_id");
599 let relation: String = row.get("relation");
600 let subject_type: String = row.get("subject_type");
601 let subject_id: String = row.get("subject_id");
602 let subject_relation: Option<String> = row.get("subject_relation");
603
604 let subject = if subject_type == "user" {
605 Subject::User(subject_id)
606 } else {
607 let parts: Vec<&str> = subject_id.split(':').collect();
608 if parts.len() == 2 {
609 Subject::UserSet {
610 namespace: parts[0].to_string(),
611 object_id: parts[1].to_string(),
612 relation: subject_relation.unwrap_or_default(),
613 }
614 } else {
615 continue;
616 }
617 };
618
619 let tuple = RelationTuple::new(&namespace, &relation, &object_id, subject);
620 self.bloom_filter.add_tuple(&tuple);
621 count += 1;
622 }
623
624 Ok(count)
625 }
626}
627
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check: an `owner` tuple on document 123 should grant `viewer`.
    ///
    /// Uses `DATABASE_URL` if set, otherwise an in-memory SQLite database. The
    /// expectation presumably relies on `viewer` inheriting (directly or
    /// transitively) from `owner` in `NamespaceConfig::document_namespace()`,
    /// which is not visible here — TODO confirm. Marked `#[ignore]`; run it with
    /// `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore]
    async fn test_basic_authorization() {
        let url =
            std::env::var("DATABASE_URL").unwrap_or_else(|_| "sqlite::memory:".to_string());

        let engine = AuthzEngine::new(&url).await.unwrap();
        engine.migrate().await.unwrap();

        let alice = || Subject::User("alice".to_string());

        let ownership = RelationTuple::new("document", "owner", "123", alice());
        engine.write_tuple(ownership).await.unwrap();

        let view_request = CheckRequest {
            namespace: "document".to_string(),
            object_id: "123".to_string(),
            relation: "viewer".to_string(),
            subject: alice(),
            context: None,
        };
        let outcome = engine.check(view_request).await.unwrap();

        assert!(outcome.allowed);
    }
}