oxify_authz/
engine.rs

1//! Authorization engine implementing the check API
2
3use crate::*;
4use moka::future::Cache;
5use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
6use sqlx::Row;
7use std::collections::{HashMap, HashSet};
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::Duration;
11
12/// The main authorization engine
13pub struct AuthzEngine {
14    pool: SqlitePool,
15    cache: Arc<Cache<String, bool>>,
16    namespace_configs: Arc<HashMap<String, NamespaceConfig>>,
17    /// Bloom filter for quick negative lookups
18    bloom_filter: Arc<AuthzBloomFilter>,
19    /// Track Bloom filter statistics
20    bloom_stats: Arc<BloomStatsTracker>,
21}
22
/// Thread-safe tracker for Bloom filter statistics.
///
/// Every counter is an independent atomic, so the tracker can be updated
/// through a shared reference. The derived `Default` yields all counters at
/// zero.
#[derive(Debug, Default)]
pub struct BloomStatsTracker {
    /// Incremented by `record_definite_negative`.
    definite_negatives: AtomicU64,
    /// Incremented by `record_potential_positive`.
    potential_positives: AtomicU64,
    /// Incremented by `record_true_positive`.
    true_positives: AtomicU64,
    /// Incremented by `record_false_positive`.
    false_positives: AtomicU64,
}
36
37impl BloomStatsTracker {
38    pub fn new() -> Self {
39        Self {
40            definite_negatives: AtomicU64::new(0),
41            potential_positives: AtomicU64::new(0),
42            true_positives: AtomicU64::new(0),
43            false_positives: AtomicU64::new(0),
44        }
45    }
46
47    pub fn record_definite_negative(&self) {
48        self.definite_negatives.fetch_add(1, Ordering::Relaxed);
49    }
50
51    pub fn record_potential_positive(&self) {
52        self.potential_positives.fetch_add(1, Ordering::Relaxed);
53    }
54
55    pub fn record_true_positive(&self) {
56        self.true_positives.fetch_add(1, Ordering::Relaxed);
57    }
58
59    pub fn record_false_positive(&self) {
60        self.false_positives.fetch_add(1, Ordering::Relaxed);
61    }
62
63    pub fn get_stats(&self) -> BloomStats {
64        BloomStats {
65            definite_negatives: self.definite_negatives.load(Ordering::Relaxed),
66            potential_positives: self.potential_positives.load(Ordering::Relaxed),
67            true_positives: self.true_positives.load(Ordering::Relaxed),
68            false_positives: self.false_positives.load(Ordering::Relaxed),
69        }
70    }
71}
72
73impl AuthzEngine {
74    /// Create a new authorization engine
75    pub async fn new(database_url: &str) -> Result<Self> {
76        let pool = SqlitePoolOptions::new()
77            .max_connections(20)
78            .acquire_timeout(Duration::from_secs(5))
79            .connect(database_url)
80            .await
81            .map_err(|e| AuthzError::DatabaseError(format!("Failed to connect: {}", e)))?;
82
83        // Cache for authorization checks (100k entries, 1 hour TTL)
84        let cache = Cache::builder()
85            .max_capacity(100_000)
86            .time_to_live(Duration::from_secs(3600))
87            .build();
88
89        // Load namespace configurations
90        let mut namespace_configs = HashMap::new();
91        namespace_configs.insert(
92            "document".to_string(),
93            NamespaceConfig::document_namespace(),
94        );
95        namespace_configs.insert("folder".to_string(), NamespaceConfig::folder_namespace());
96
97        // Initialize Bloom filter with 1M capacity and 1% false positive rate
98        let bloom_filter = Arc::new(AuthzBloomFilter::with_config(BloomConfig {
99            expected_items: 1_000_000,
100            false_positive_rate: 0.01,
101        }));
102
103        Ok(Self {
104            pool,
105            cache: Arc::new(cache),
106            namespace_configs: Arc::new(namespace_configs),
107            bloom_filter,
108            bloom_stats: Arc::new(BloomStatsTracker::new()),
109        })
110    }
111
112    /// Get the Bloom filter statistics
113    pub fn bloom_stats(&self) -> BloomStats {
114        self.bloom_stats.get_stats()
115    }
116
117    /// Get a reference to the Bloom filter
118    pub fn bloom_filter(&self) -> &AuthzBloomFilter {
119        &self.bloom_filter
120    }
121
122    /// Write a relation tuple
123    pub async fn write_tuple(&self, tuple: RelationTuple) -> Result<()> {
124        sqlx::query(
125            r#"
126            INSERT OR IGNORE INTO authz_relation_tuples
127                (namespace, object_id, relation, subject_type, subject_id, subject_relation)
128            VALUES (?, ?, ?, ?, ?, ?)
129            "#,
130        )
131        .bind(&tuple.namespace)
132        .bind(&tuple.object_id)
133        .bind(&tuple.relation)
134        .bind(match &tuple.subject {
135            Subject::User(_) => "user",
136            Subject::UserSet { .. } => "userset",
137        })
138        .bind(match &tuple.subject {
139            Subject::User(id) => id.clone(),
140            Subject::UserSet {
141                namespace,
142                object_id,
143                ..
144            } => format!("{}:{}", namespace, object_id),
145        })
146        .bind(match &tuple.subject {
147            Subject::User(_) => None,
148            Subject::UserSet { relation, .. } => Some(relation.clone()),
149        })
150        .execute(&self.pool)
151        .await
152        .map_err(|e| AuthzError::DatabaseError(format!("Failed to write tuple: {}", e)))?;
153
154        // Add to Bloom filter for quick negative lookups
155        self.bloom_filter.add_tuple(&tuple);
156
157        // Invalidate cache for this object
158        let cache_key = self.cache_key(&tuple.namespace, &tuple.object_id, &tuple.relation);
159        self.cache.invalidate(&cache_key).await;
160
161        Ok(())
162    }
163
164    /// Delete a relation tuple
165    pub async fn delete_tuple(&self, tuple: RelationTuple) -> Result<()> {
166        sqlx::query(
167            r#"
168            DELETE FROM authz_relation_tuples
169            WHERE namespace = ?
170              AND object_id = ?
171              AND relation = ?
172              AND subject_type = ?
173              AND subject_id = ?
174            "#,
175        )
176        .bind(&tuple.namespace)
177        .bind(&tuple.object_id)
178        .bind(&tuple.relation)
179        .bind(match &tuple.subject {
180            Subject::User(_) => "user",
181            Subject::UserSet { .. } => "userset",
182        })
183        .bind(match &tuple.subject {
184            Subject::User(id) => id.clone(),
185            Subject::UserSet {
186                namespace,
187                object_id,
188                ..
189            } => format!("{}:{}", namespace, object_id),
190        })
191        .execute(&self.pool)
192        .await
193        .map_err(|e| AuthzError::DatabaseError(format!("Failed to delete tuple: {}", e)))?;
194
195        // Invalidate cache
196        let cache_key = self.cache_key(&tuple.namespace, &tuple.object_id, &tuple.relation);
197        self.cache.invalidate(&cache_key).await;
198
199        Ok(())
200    }
201
202    /// Check if a subject has a relation to an object
203    pub async fn check(&self, request: CheckRequest) -> Result<CheckResponse> {
204        // Generate cache key
205        let cache_key = format!(
206            "check:{}:{}:{}:{}",
207            request.namespace, request.object_id, request.relation, request.subject
208        );
209
210        // Check cache first
211        if let Some(allowed) = self.cache.get(&cache_key).await {
212            return Ok(CheckResponse {
213                allowed,
214                cached: true,
215            });
216        }
217
218        // Perform recursive check
219        let allowed = self
220            .check_recursive(&request, 0, &mut HashSet::new())
221            .await?;
222
223        // Cache the result
224        self.cache.insert(cache_key, allowed).await;
225
226        Ok(CheckResponse {
227            allowed,
228            cached: false,
229        })
230    }
231
232    /// Recursive check implementation (depth-first search)
233    fn check_recursive<'a>(
234        &'a self,
235        request: &'a CheckRequest,
236        depth: usize,
237        visited: &'a mut HashSet<String>,
238    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<bool>> + Send + 'a>> {
239        Box::pin(async move {
240            // Prevent infinite recursion
241            if depth > 10 {
242                return Err(AuthzError::CycleDetected);
243            }
244
245            let visit_key = format!(
246                "{}:{}:{}",
247                request.namespace, request.object_id, request.relation
248            );
249            if visited.contains(&visit_key) {
250                return Ok(false); // Already visited, avoid cycle
251            }
252            visited.insert(visit_key);
253
254            // Direct check: Is there a direct tuple?
255            let direct = self.check_direct(request).await?;
256            if direct {
257                return Ok(true);
258            }
259
260            // Check inherited relations
261            if let Some(namespace_config) = self.namespace_configs.get(&request.namespace) {
262                if let Some(relation_config) = namespace_config
263                    .relations
264                    .iter()
265                    .find(|r| r.name == request.relation)
266                {
267                    // Check if subject has any inherited relation
268                    for inherited_relation in &relation_config.inherits_from {
269                        let inherited_request = CheckRequest {
270                            namespace: request.namespace.clone(),
271                            object_id: request.object_id.clone(),
272                            relation: inherited_relation.clone(),
273                            subject: request.subject.clone(),
274                            context: None,
275                        };
276
277                        if self
278                            .check_recursive(&inherited_request, depth + 1, visited)
279                            .await?
280                        {
281                            return Ok(true);
282                        }
283                    }
284                }
285            }
286
287            // Check userset expansion
288            if let Subject::User(user_id) = &request.subject {
289                // Find all usersets this user belongs to
290                let usersets = self.find_usersets_for_user(user_id).await?;
291
292                for userset in usersets {
293                    let userset_request = CheckRequest {
294                        namespace: request.namespace.clone(),
295                        object_id: request.object_id.clone(),
296                        relation: request.relation.clone(),
297                        subject: userset,
298                        context: None,
299                    };
300
301                    if self
302                        .check_recursive(&userset_request, depth + 1, visited)
303                        .await?
304                    {
305                        return Ok(true);
306                    }
307                }
308            }
309
310            Ok(false)
311        })
312    }
313
314    /// Check for a direct tuple match
315    async fn check_direct(&self, request: &CheckRequest) -> Result<bool> {
316        let row = sqlx::query(
317            r#"
318            SELECT COUNT(*) as count FROM authz_relation_tuples
319            WHERE namespace = ?
320              AND object_id = ?
321              AND relation = ?
322              AND subject_type = ?
323              AND subject_id = ?
324            "#,
325        )
326        .bind(&request.namespace)
327        .bind(&request.object_id)
328        .bind(&request.relation)
329        .bind(match &request.subject {
330            Subject::User(_) => "user",
331            Subject::UserSet { .. } => "userset",
332        })
333        .bind(match &request.subject {
334            Subject::User(id) => id.clone(),
335            Subject::UserSet {
336                namespace,
337                object_id,
338                ..
339            } => format!("{}:{}", namespace, object_id),
340        })
341        .fetch_one(&self.pool)
342        .await
343        .map_err(|e| AuthzError::DatabaseError(format!("Failed to check direct: {}", e)))?;
344
345        let count: i64 = row.try_get("count").unwrap_or(0);
346        Ok(count > 0)
347    }
348
349    /// Find all usersets a user belongs to
350    async fn find_usersets_for_user(&self, user_id: &str) -> Result<Vec<Subject>> {
351        let rows = sqlx::query(
352            r#"
353            SELECT namespace, object_id, relation
354            FROM authz_relation_tuples
355            WHERE subject_type = 'user'
356              AND subject_id = ?
357            "#,
358        )
359        .bind(user_id)
360        .fetch_all(&self.pool)
361        .await
362        .map_err(|e| AuthzError::DatabaseError(format!("Failed to find usersets: {}", e)))?;
363
364        let mut usersets = Vec::new();
365        for row in rows {
366            let namespace: String = row.get("namespace");
367            let object_id: String = row.get("object_id");
368            let relation: String = row.get("relation");
369
370            usersets.push(Subject::UserSet {
371                namespace,
372                object_id,
373                relation,
374            });
375        }
376
377        Ok(usersets)
378    }
379
380    /// Expand a relation to find all subjects
381    pub async fn expand(&self, request: ExpandRequest) -> Result<ExpandResponse> {
382        let rows = sqlx::query(
383            r#"
384            SELECT subject_type, subject_id, subject_relation
385            FROM authz_relation_tuples
386            WHERE namespace = ?
387              AND object_id = ?
388              AND relation = ?
389            "#,
390        )
391        .bind(&request.namespace)
392        .bind(&request.object_id)
393        .bind(&request.relation)
394        .fetch_all(&self.pool)
395        .await
396        .map_err(|e| AuthzError::DatabaseError(format!("Failed to expand: {}", e)))?;
397
398        let mut subjects = Vec::new();
399        for row in rows {
400            let subject_type: String = row.get("subject_type");
401            let subject_id: String = row.get("subject_id");
402
403            let subject = if subject_type == "user" {
404                Subject::User(subject_id)
405            } else {
406                let subject_relation: Option<String> = row.get("subject_relation");
407                let parts: Vec<&str> = subject_id.split(':').collect();
408                if let (2, Some(relation)) = (parts.len(), subject_relation) {
409                    Subject::UserSet {
410                        namespace: parts[0].to_string(),
411                        object_id: parts[1].to_string(),
412                        relation,
413                    }
414                } else {
415                    continue; // Skip invalid usersets
416                }
417            };
418
419            subjects.push(subject);
420        }
421
422        Ok(ExpandResponse { subjects })
423    }
424
425    /// Generate cache key
426    fn cache_key(&self, namespace: &str, object_id: &str, relation: &str) -> String {
427        format!("{}:{}:{}", namespace, object_id, relation)
428    }
429
430    /// Run database migrations
431    pub async fn migrate(&self) -> Result<()> {
432        sqlx::query(include_str!("../migrations/001_init.sql"))
433            .execute(&self.pool)
434            .await
435            .map_err(|e| AuthzError::DatabaseError(format!("Migration failed: {}", e)))?;
436
437        Ok(())
438    }
439
440    /// Batch check multiple authorization requests efficiently
441    ///
442    /// This method uses Bloom filter to skip definitely non-existent tuples
443    /// and PostgreSQL ANY() for efficient batch queries.
444    ///
445    /// Performance targets:
446    /// - 100 checks: <50ms total
447    /// - Bloom filter reduces DB queries by ~50%
448    pub async fn batch_check(&self, requests: &[CheckRequest]) -> Result<Vec<CheckResponse>> {
449        if requests.is_empty() {
450            return Ok(Vec::new());
451        }
452
453        let mut results = vec![None; requests.len()];
454
455        // Phase 1: Check cache first
456        let mut cache_misses = Vec::new();
457        for (idx, request) in requests.iter().enumerate() {
458            let cache_key = format!(
459                "check:{}:{}:{}:{}",
460                request.namespace, request.object_id, request.relation, request.subject
461            );
462
463            if let Some(allowed) = self.cache.get(&cache_key).await {
464                results[idx] = Some(CheckResponse {
465                    allowed,
466                    cached: true,
467                });
468            } else {
469                cache_misses.push((idx, request, cache_key));
470            }
471        }
472
473        if cache_misses.is_empty() {
474            return Ok(results.into_iter().map(|r| r.unwrap()).collect());
475        }
476
477        // Phase 2: Use Bloom filter to filter out definitely non-existent tuples
478        let mut bloom_positives = Vec::new();
479        for (idx, request, cache_key) in cache_misses {
480            if self.bloom_filter.might_contain(request) {
481                self.bloom_stats.record_potential_positive();
482                bloom_positives.push((idx, request, cache_key));
483            } else {
484                // Bloom filter says definitely not there
485                self.bloom_stats.record_definite_negative();
486                results[idx] = Some(CheckResponse {
487                    allowed: false,
488                    cached: false,
489                });
490            }
491        }
492
493        if bloom_positives.is_empty() {
494            return Ok(results.into_iter().map(|r| r.unwrap()).collect());
495        }
496
497        // Phase 3: Batch query PostgreSQL using ANY() for direct checks
498        let db_results = self.batch_check_direct(&bloom_positives).await?;
499
500        // Phase 4: For items not found directly, do recursive checks
501        for ((idx, request, cache_key), found) in bloom_positives.into_iter().zip(db_results) {
502            let allowed = if found {
503                self.bloom_stats.record_true_positive();
504                true
505            } else {
506                // Need to do recursive check for inherited relations
507                let recursive_result = self
508                    .check_recursive(request, 0, &mut HashSet::new())
509                    .await?;
510                if recursive_result {
511                    self.bloom_stats.record_true_positive();
512                } else {
513                    self.bloom_stats.record_false_positive();
514                }
515                recursive_result
516            };
517
518            // Cache the result
519            self.cache.insert(cache_key, allowed).await;
520
521            results[idx] = Some(CheckResponse {
522                allowed,
523                cached: false,
524            });
525        }
526
527        Ok(results.into_iter().map(|r| r.unwrap()).collect())
528    }
529
530    /// Batch check direct tuples (SQLite version uses individual queries)
531    async fn batch_check_direct(
532        &self,
533        requests: &[(usize, &CheckRequest, String)],
534    ) -> Result<Vec<bool>> {
535        if requests.is_empty() {
536            return Ok(Vec::new());
537        }
538
539        // For SQLite, we check each request individually
540        // This is less efficient than PostgreSQL's unnest, but SQLite doesn't support arrays
541        let mut results = Vec::with_capacity(requests.len());
542
543        for (_, request, _) in requests {
544            let subject_type = match &request.subject {
545                Subject::User(_) => "user",
546                Subject::UserSet { .. } => "userset",
547            };
548            let subject_id = match &request.subject {
549                Subject::User(id) => id.clone(),
550                Subject::UserSet {
551                    namespace,
552                    object_id,
553                    ..
554                } => format!("{}:{}", namespace, object_id),
555            };
556
557            let row = sqlx::query(
558                r#"
559                SELECT COUNT(*) as count FROM authz_relation_tuples
560                WHERE namespace = ?
561                  AND object_id = ?
562                  AND relation = ?
563                  AND subject_type = ?
564                  AND subject_id = ?
565                "#,
566            )
567            .bind(&request.namespace)
568            .bind(&request.object_id)
569            .bind(&request.relation)
570            .bind(subject_type)
571            .bind(&subject_id)
572            .fetch_one(&self.pool)
573            .await
574            .map_err(|e| AuthzError::DatabaseError(format!("Batch check failed: {}", e)))?;
575
576            let count: i64 = row.try_get("count").unwrap_or(0);
577            results.push(count > 0);
578        }
579
580        Ok(results)
581    }
582
583    /// Load existing tuples into the Bloom filter (for warm-up)
584    pub async fn warm_bloom_filter(&self) -> Result<usize> {
585        let rows = sqlx::query(
586            r#"
587            SELECT namespace, object_id, relation, subject_type, subject_id, subject_relation
588            FROM authz_relation_tuples
589            "#,
590        )
591        .fetch_all(&self.pool)
592        .await
593        .map_err(|e| AuthzError::DatabaseError(format!("Failed to load tuples: {}", e)))?;
594
595        let mut count = 0;
596        for row in rows {
597            let namespace: String = row.get("namespace");
598            let object_id: String = row.get("object_id");
599            let relation: String = row.get("relation");
600            let subject_type: String = row.get("subject_type");
601            let subject_id: String = row.get("subject_id");
602            let subject_relation: Option<String> = row.get("subject_relation");
603
604            let subject = if subject_type == "user" {
605                Subject::User(subject_id)
606            } else {
607                let parts: Vec<&str> = subject_id.split(':').collect();
608                if parts.len() == 2 {
609                    Subject::UserSet {
610                        namespace: parts[0].to_string(),
611                        object_id: parts[1].to_string(),
612                        relation: subject_relation.unwrap_or_default(),
613                    }
614                } else {
615                    continue;
616                }
617            };
618
619            let tuple = RelationTuple::new(&namespace, &relation, &object_id, subject);
620            self.bloom_filter.add_tuple(&tuple);
621            count += 1;
622        }
623
624        Ok(count)
625    }
626}
627
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes an `owner` tuple for alice on document 123 and expects the
    /// `viewer` check for alice on the same document to be allowed.
    #[tokio::test]
    #[ignore] // Requires database
    async fn test_basic_authorization() {
        let database_url = match std::env::var("DATABASE_URL") {
            Ok(url) => url,
            Err(_) => "sqlite::memory:".to_string(),
        };

        let engine = AuthzEngine::new(&database_url).await.unwrap();
        engine.migrate().await.unwrap();

        let alice = Subject::User("alice".to_string());

        // Write: alice owns document:123
        let ownership = RelationTuple::new("document", "owner", "123", alice.clone());
        engine.write_tuple(ownership).await.unwrap();

        // Check: alice can view (owner inherits viewer)
        let view_request = CheckRequest {
            namespace: "document".to_string(),
            object_id: "123".to_string(),
            relation: "viewer".to_string(),
            subject: alice,
            context: None,
        };
        let response = engine.check(view_request).await.unwrap();

        assert!(response.allowed);
    }
}
666}