// postgres_bulk_rebac/postgres_bulk_rebac.rs

//! PostgreSQL-backed batch ReBAC example.
//!
//! This models a list endpoint with two policies:
//!
//! - public posts are visible through an in-memory predicate policy
//! - private posts are visible when a `viewer` relationship exists in PostgreSQL
//!
//! The policy stack stays in Gatehouse. PostgreSQL is only responsible for
//! loading relationship facts, and `EvaluationSession` batches, deduplicates,
//! caches, and expands those facts back into caller order.
11
12use async_trait::async_trait;
13use gatehouse::{
14    EvaluationSession, FactLoadError, FactLoadResult, FactSource, PermissionChecker, PolicyBuilder,
15    RebacPolicy, RelationshipQuery,
16};
17use std::fmt;
18use std::sync::Arc;
19use std::time::{Duration, Instant};
20use tokio_postgres::{Client, NoTls, Statement};
21use uuid::Uuid;
22
23type RelationshipKey = RelationshipQuery<Uuid, Uuid, Relation>;
24
25#[derive(Clone)]
26struct User {
27    id: Uuid,
28}
29
30#[derive(Clone)]
31struct Post {
32    id: Uuid,
33    public: bool,
34}
35
/// Unit marker for the operation being authorized (viewing a post).
struct View;
37
/// Relationship kinds that can link a subject to a post.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Relation {
    /// The subject may view the post; rendered as the text `viewer`.
    Viewer,
}

impl Relation {
    /// Returns the text form of the relation, as bound to the SQL `relationship` parameter.
    fn as_str(self) -> &'static str {
        match self {
            Self::Viewer => "viewer",
        }
    }
}

impl fmt::Display for Relation {
    /// Writes the same text as [`Relation::as_str`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.as_str())
    }
}
56
57#[derive(Clone)]
58struct PgRelationshipSource {
59    client: Arc<Client>,
60    point_stmt: Arc<Statement>,
61    bulk_stmt: Arc<Statement>,
62}
63
64impl PgRelationshipSource {
65    async fn load_point(&self, key: &RelationshipKey) -> FactLoadResult<bool> {
66        let relationship = key.relation.as_str();
67        match self
68            .client
69            .query_one(
70                &*self.point_stmt,
71                &[&key.subject_id, &relationship, &key.resource_id],
72            )
73            .await
74        {
75            Ok(row) => FactLoadResult::Found(row.get("allowed")),
76            Err(error) => FactLoadResult::Error(FactLoadError::backend(error)),
77        }
78    }
79
80    async fn load_bulk(&self, keys: &[RelationshipKey]) -> Vec<FactLoadResult<bool>> {
81        let subject_ids = keys.iter().map(|key| key.subject_id).collect::<Vec<_>>();
82        let post_ids = keys.iter().map(|key| key.resource_id).collect::<Vec<_>>();
83        let relationships = keys
84            .iter()
85            .map(|key| key.relation.as_str())
86            .collect::<Vec<_>>();
87
88        match self
89            .client
90            .query(&*self.bulk_stmt, &[&subject_ids, &post_ids, &relationships])
91            .await
92        {
93            Ok(rows) => rows
94                .into_iter()
95                .map(|row| FactLoadResult::Found(row.get("allowed")))
96                .collect(),
97            Err(error) => {
98                let error = FactLoadError::backend(error);
99                keys.iter()
100                    .map(|_| FactLoadResult::Error(error.clone()))
101                    .collect()
102            }
103        }
104    }
105}
106
107#[async_trait]
108impl FactSource<RelationshipKey> for PgRelationshipSource {
109    async fn load_many(&self, keys: &[RelationshipKey]) -> Vec<FactLoadResult<bool>> {
110        if keys.len() == 1 {
111            return vec![self.load_point(&keys[0]).await];
112        }
113
114        self.load_bulk(keys).await
115    }
116}
117
118async fn assert_point_and_bulk_agree(source: &PgRelationshipSource, keys: &[RelationshipKey]) {
119    for key in keys {
120        let point = source.load_point(key).await;
121        let bulk = source
122            .load_bulk(std::slice::from_ref(key))
123            .await
124            .into_iter()
125            .next()
126            .expect("bulk load for one key should return one result");
127
128        match (point, bulk) {
129            (FactLoadResult::Found(point), FactLoadResult::Found(bulk)) => {
130                assert_eq!(point, bulk, "point and bulk SQL should agree for {key:?}");
131            }
132            (point, bulk) => {
133                panic!(
134                    "point and bulk SQL should both succeed in the example: {point:?} vs {bulk:?}"
135                );
136            }
137        }
138    }
139}
140
141fn build_checker() -> PermissionChecker<User, Post, View, ()> {
142    let public_posts = PolicyBuilder::<User, Post, View, ()>::new("PublicPost")
143        .resources(|post| post.public)
144        .build();
145    let viewer_relationship = RebacPolicy::new(
146        |user: &User| user.id,
147        |post: &Post| post.id,
148        Relation::Viewer,
149    );
150
151    let mut checker = PermissionChecker::new();
152    checker.add_policy(public_posts);
153    checker.add_policy(viewer_relationship);
154    checker
155}
156
157fn session_with(source: &Arc<dyn FactSource<RelationshipKey>>) -> EvaluationSession {
158    EvaluationSession::builder()
159        .with_arc::<RelationshipKey>(Arc::clone(source))
160        .build()
161}
162
163#[tokio::main]
164async fn main() {
165    let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
166        "host=localhost port=15432 user=postgres password=test dbname=awa_test".to_string()
167    });
168
169    let (client, connection) = tokio_postgres::connect(&database_url, NoTls)
170        .await
171        .expect("connect to PostgreSQL");
172    tokio::spawn(async move {
173        if let Err(error) = connection.await {
174            eprintln!("postgres connection error: {error}");
175        }
176    });
177    let client = Arc::new(client);
178
179    let version: String = client
180        .query_one("SELECT version()", &[])
181        .await
182        .expect("version query should succeed")
183        .get(0);
184    println!("{version}");
185
186    client
187        .batch_execute(
188            "
189            DROP TABLE IF EXISTS gatehouse_example_post_relationships;
190            CREATE UNLOGGED TABLE gatehouse_example_post_relationships (
191                subject_id uuid NOT NULL,
192                post_id uuid NOT NULL,
193                relationship text NOT NULL,
194                PRIMARY KEY (subject_id, post_id, relationship)
195            );
196            ",
197        )
198        .await
199        .expect("setup schema");
200
201    let subject = User { id: Uuid::new_v4() };
202    let posts = (0..10_000)
203        .map(|index| Post {
204            id: Uuid::new_v4(),
205            public: index % 5 == 0,
206        })
207        .collect::<Vec<_>>();
208    let granted_ids = posts
209        .iter()
210        .enumerate()
211        .filter_map(|(index, post)| (!post.public && index % 2 == 0).then_some(post.id))
212        .collect::<Vec<_>>();
213    let relationships = vec![Relation::Viewer.as_str(); granted_ids.len()];
214    let subject_ids = vec![subject.id; granted_ids.len()];
215
216    client
217        .execute(
218            "
219            INSERT INTO gatehouse_example_post_relationships (subject_id, post_id, relationship)
220            SELECT *
221            FROM unnest($1::uuid[], $2::uuid[], $3::text[])
222            ",
223            &[&subject_ids, &granted_ids, &relationships],
224        )
225        .await
226        .expect("seed grants");
227
228    let point_stmt = Arc::new(
229        client
230            .prepare(
231                "
232                SELECT EXISTS (
233                    SELECT 1
234                    FROM gatehouse_example_post_relationships
235                    WHERE subject_id = $1
236                      AND relationship = $2
237                      AND post_id = $3
238                ) AS allowed
239                ",
240            )
241            .await
242            .expect("prepare point query"),
243    );
244    let bulk_stmt = Arc::new(
245        client
246            .prepare(
247                "
248                WITH candidate_relationships AS (
249                    SELECT subject_id, post_id, relationship, ord
250                    FROM unnest($1::uuid[], $2::uuid[], $3::text[])
251                        WITH ORDINALITY AS input(subject_id, post_id, relationship, ord)
252                )
253                SELECT
254                    COALESCE(bool_or(g.post_id IS NOT NULL), false) AS allowed
255                FROM candidate_relationships c
256                LEFT JOIN gatehouse_example_post_relationships g
257                  ON g.subject_id = c.subject_id
258                 AND g.relationship = c.relationship
259                 AND g.post_id = c.post_id
260                GROUP BY c.ord, c.subject_id, c.post_id, c.relationship
261                ORDER BY c.ord
262                ",
263            )
264            .await
265            .expect("prepare bulk query"),
266    );
267
268    let source = Arc::new(PgRelationshipSource {
269        client,
270        point_stmt,
271        bulk_stmt,
272    });
273    assert_point_and_bulk_agree(
274        &source,
275        &[
276            RelationshipQuery {
277                subject_id: subject.id,
278                resource_id: posts
279                    .iter()
280                    .find(|post| granted_ids.contains(&post.id))
281                    .expect("fixture should include a granted private post")
282                    .id,
283                relation: Relation::Viewer,
284            },
285            RelationshipQuery {
286                subject_id: subject.id,
287                resource_id: posts
288                    .iter()
289                    .enumerate()
290                    .find(|(index, post)| !post.public && index % 2 == 1)
291                    .expect("fixture should include a denied private post")
292                    .1
293                    .id,
294                relation: Relation::Viewer,
295            },
296        ],
297    )
298    .await;
299    let source: Arc<dyn FactSource<RelationshipKey>> = source;
300    let checker = build_checker();
301
302    println!("size,relationship_checks,naive_ms,bulk_ms,allowed,improvement");
303    for &size in &[10usize, 100, 1_000, 5_000, 10_000] {
304        let sample = posts.iter().take(size).cloned().collect::<Vec<_>>();
305        let relationship_checks = sample.iter().filter(|post| !post.public).count();
306        let naive = measure(|| async {
307            let mut allowed = 0usize;
308            for post in &sample {
309                let session = session_with(&source);
310                if checker
311                    .evaluate_in_session(&session, &subject, &View, post, &())
312                    .await
313                    .is_granted()
314                {
315                    allowed += 1;
316                }
317            }
318            allowed
319        })
320        .await;
321
322        let bulk = measure(|| async {
323            let session = session_with(&source);
324            checker
325                .filter_authorized_in_session_by_resource(
326                    &session,
327                    &subject,
328                    &View,
329                    sample.clone(),
330                    &(),
331                    |post| post,
332                )
333                .await
334                .len()
335        })
336        .await;
337
338        assert_eq!(naive.output, bulk.output);
339        println!(
340            "{size},{relationship_checks},{:.3},{:.3},{},x{:.1}",
341            naive.elapsed.as_secs_f64() * 1_000.0,
342            bulk.elapsed.as_secs_f64() * 1_000.0,
343            bulk.output,
344            naive.elapsed.as_secs_f64() / bulk.elapsed.as_secs_f64()
345        );
346    }
347}
348
/// Outcome of [`measure`]: the fastest run's duration and the value that run produced.
struct Measurement<T> {
    /// Wall-clock time of the fastest run.
    elapsed: Duration,
    /// Value returned by the fastest run.
    output: T,
}

/// Runs the future produced by `f` three times, one after another, and keeps the
/// fastest run.
///
/// Each run is timed from just before `f()` is called until its future completes.
/// The outputs of the slower runs are dropped.
async fn measure<F, Fut, T>(mut f: F) -> Measurement<T>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = T>,
{
    let mut best_elapsed = Duration::MAX;
    let mut best_output = None;

    for _ in 0..3 {
        let start = Instant::now();
        let output = f().await;
        let elapsed = start.elapsed();
        // Strictly faster only: on a tie the earlier run's output is kept.
        if elapsed < best_elapsed {
            best_elapsed = elapsed;
            best_output = Some(output);
        }
    }

    Measurement {
        elapsed: best_elapsed,
        output: best_output.expect("measurement should run at least once"),
    }
}