1use async_trait::async_trait;
13use gatehouse::{
14 EvaluationSession, FactLoadError, FactLoadResult, FactSource, PermissionChecker, PolicyBuilder,
15 RebacPolicy, RelationshipQuery,
16};
17use std::fmt;
18use std::sync::Arc;
19use std::time::{Duration, Instant};
20use tokio_postgres::{Client, NoTls, Statement};
21use uuid::Uuid;
22
23type RelationshipKey = RelationshipQuery<Uuid, Uuid, Relation>;
24
25#[derive(Clone)]
26struct User {
27 id: Uuid,
28}
29
30#[derive(Clone)]
31struct Post {
32 id: Uuid,
33 public: bool,
34}
35
/// Marker for the "view" action evaluated by the permission checker.
#[derive(Clone, Copy, Debug)]
struct View;
37
/// Relationship kinds persisted in the `relationship` text column.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Relation {
    /// A `viewer` relationship between a user and a post.
    Viewer,
}

impl Relation {
    /// Returns the exact text stored in (and queried from) the database.
    fn as_str(self) -> &'static str {
        match self {
            Relation::Viewer => "viewer",
        }
    }
}

impl fmt::Display for Relation {
    /// Writes the database spelling verbatim. Width/fill flags are not applied
    /// because the text goes through `write_str`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str(Relation::as_str(*self))
    }
}
56
57#[derive(Clone)]
58struct PgRelationshipSource {
59 client: Arc<Client>,
60 point_stmt: Arc<Statement>,
61 bulk_stmt: Arc<Statement>,
62}
63
64impl PgRelationshipSource {
65 async fn load_point(&self, key: &RelationshipKey) -> FactLoadResult<bool> {
66 let relationship = key.relation.as_str();
67 match self
68 .client
69 .query_one(
70 &*self.point_stmt,
71 &[&key.subject_id, &relationship, &key.resource_id],
72 )
73 .await
74 {
75 Ok(row) => FactLoadResult::Found(row.get("allowed")),
76 Err(error) => FactLoadResult::Error(FactLoadError::backend(error)),
77 }
78 }
79
80 async fn load_bulk(&self, keys: &[RelationshipKey]) -> Vec<FactLoadResult<bool>> {
81 let subject_ids = keys.iter().map(|key| key.subject_id).collect::<Vec<_>>();
82 let post_ids = keys.iter().map(|key| key.resource_id).collect::<Vec<_>>();
83 let relationships = keys
84 .iter()
85 .map(|key| key.relation.as_str())
86 .collect::<Vec<_>>();
87
88 match self
89 .client
90 .query(&*self.bulk_stmt, &[&subject_ids, &post_ids, &relationships])
91 .await
92 {
93 Ok(rows) => rows
94 .into_iter()
95 .map(|row| FactLoadResult::Found(row.get("allowed")))
96 .collect(),
97 Err(error) => {
98 let error = FactLoadError::backend(error);
99 keys.iter()
100 .map(|_| FactLoadResult::Error(error.clone()))
101 .collect()
102 }
103 }
104 }
105}
106
107#[async_trait]
108impl FactSource<RelationshipKey> for PgRelationshipSource {
109 async fn load_many(&self, keys: &[RelationshipKey]) -> Vec<FactLoadResult<bool>> {
110 if keys.len() == 1 {
111 return vec![self.load_point(&keys[0]).await];
112 }
113
114 self.load_bulk(keys).await
115 }
116}
117
118async fn assert_point_and_bulk_agree(source: &PgRelationshipSource, keys: &[RelationshipKey]) {
119 for key in keys {
120 let point = source.load_point(key).await;
121 let bulk = source
122 .load_bulk(std::slice::from_ref(key))
123 .await
124 .into_iter()
125 .next()
126 .expect("bulk load for one key should return one result");
127
128 match (point, bulk) {
129 (FactLoadResult::Found(point), FactLoadResult::Found(bulk)) => {
130 assert_eq!(point, bulk, "point and bulk SQL should agree for {key:?}");
131 }
132 (point, bulk) => {
133 panic!(
134 "point and bulk SQL should both succeed in the example: {point:?} vs {bulk:?}"
135 );
136 }
137 }
138 }
139}
140
141fn build_checker() -> PermissionChecker<User, Post, View, ()> {
142 let public_posts = PolicyBuilder::<User, Post, View, ()>::new("PublicPost")
143 .resources(|post| post.public)
144 .build();
145 let viewer_relationship = RebacPolicy::new(
146 |user: &User| user.id,
147 |post: &Post| post.id,
148 Relation::Viewer,
149 );
150
151 let mut checker = PermissionChecker::new();
152 checker.add_policy(public_posts);
153 checker.add_policy(viewer_relationship);
154 checker
155}
156
157fn session_with(source: &Arc<dyn FactSource<RelationshipKey>>) -> EvaluationSession {
158 EvaluationSession::builder()
159 .with_arc::<RelationshipKey>(Arc::clone(source))
160 .build()
161}
162
163#[tokio::main]
164async fn main() {
165 let database_url = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
166 "host=localhost port=15432 user=postgres password=test dbname=awa_test".to_string()
167 });
168
169 let (client, connection) = tokio_postgres::connect(&database_url, NoTls)
170 .await
171 .expect("connect to PostgreSQL");
172 tokio::spawn(async move {
173 if let Err(error) = connection.await {
174 eprintln!("postgres connection error: {error}");
175 }
176 });
177 let client = Arc::new(client);
178
179 let version: String = client
180 .query_one("SELECT version()", &[])
181 .await
182 .expect("version query should succeed")
183 .get(0);
184 println!("{version}");
185
186 client
187 .batch_execute(
188 "
189 DROP TABLE IF EXISTS gatehouse_example_post_relationships;
190 CREATE UNLOGGED TABLE gatehouse_example_post_relationships (
191 subject_id uuid NOT NULL,
192 post_id uuid NOT NULL,
193 relationship text NOT NULL,
194 PRIMARY KEY (subject_id, post_id, relationship)
195 );
196 ",
197 )
198 .await
199 .expect("setup schema");
200
201 let subject = User { id: Uuid::new_v4() };
202 let posts = (0..10_000)
203 .map(|index| Post {
204 id: Uuid::new_v4(),
205 public: index % 5 == 0,
206 })
207 .collect::<Vec<_>>();
208 let granted_ids = posts
209 .iter()
210 .enumerate()
211 .filter_map(|(index, post)| (!post.public && index % 2 == 0).then_some(post.id))
212 .collect::<Vec<_>>();
213 let relationships = vec![Relation::Viewer.as_str(); granted_ids.len()];
214 let subject_ids = vec![subject.id; granted_ids.len()];
215
216 client
217 .execute(
218 "
219 INSERT INTO gatehouse_example_post_relationships (subject_id, post_id, relationship)
220 SELECT *
221 FROM unnest($1::uuid[], $2::uuid[], $3::text[])
222 ",
223 &[&subject_ids, &granted_ids, &relationships],
224 )
225 .await
226 .expect("seed grants");
227
228 let point_stmt = Arc::new(
229 client
230 .prepare(
231 "
232 SELECT EXISTS (
233 SELECT 1
234 FROM gatehouse_example_post_relationships
235 WHERE subject_id = $1
236 AND relationship = $2
237 AND post_id = $3
238 ) AS allowed
239 ",
240 )
241 .await
242 .expect("prepare point query"),
243 );
244 let bulk_stmt = Arc::new(
245 client
246 .prepare(
247 "
248 WITH candidate_relationships AS (
249 SELECT subject_id, post_id, relationship, ord
250 FROM unnest($1::uuid[], $2::uuid[], $3::text[])
251 WITH ORDINALITY AS input(subject_id, post_id, relationship, ord)
252 )
253 SELECT
254 COALESCE(bool_or(g.post_id IS NOT NULL), false) AS allowed
255 FROM candidate_relationships c
256 LEFT JOIN gatehouse_example_post_relationships g
257 ON g.subject_id = c.subject_id
258 AND g.relationship = c.relationship
259 AND g.post_id = c.post_id
260 GROUP BY c.ord, c.subject_id, c.post_id, c.relationship
261 ORDER BY c.ord
262 ",
263 )
264 .await
265 .expect("prepare bulk query"),
266 );
267
268 let source = Arc::new(PgRelationshipSource {
269 client,
270 point_stmt,
271 bulk_stmt,
272 });
273 assert_point_and_bulk_agree(
274 &source,
275 &[
276 RelationshipQuery {
277 subject_id: subject.id,
278 resource_id: posts
279 .iter()
280 .find(|post| granted_ids.contains(&post.id))
281 .expect("fixture should include a granted private post")
282 .id,
283 relation: Relation::Viewer,
284 },
285 RelationshipQuery {
286 subject_id: subject.id,
287 resource_id: posts
288 .iter()
289 .enumerate()
290 .find(|(index, post)| !post.public && index % 2 == 1)
291 .expect("fixture should include a denied private post")
292 .1
293 .id,
294 relation: Relation::Viewer,
295 },
296 ],
297 )
298 .await;
299 let source: Arc<dyn FactSource<RelationshipKey>> = source;
300 let checker = build_checker();
301
302 println!("size,relationship_checks,naive_ms,bulk_ms,allowed,improvement");
303 for &size in &[10usize, 100, 1_000, 5_000, 10_000] {
304 let sample = posts.iter().take(size).cloned().collect::<Vec<_>>();
305 let relationship_checks = sample.iter().filter(|post| !post.public).count();
306 let naive = measure(|| async {
307 let mut allowed = 0usize;
308 for post in &sample {
309 let session = session_with(&source);
310 if checker
311 .evaluate_in_session(&session, &subject, &View, post, &())
312 .await
313 .is_granted()
314 {
315 allowed += 1;
316 }
317 }
318 allowed
319 })
320 .await;
321
322 let bulk = measure(|| async {
323 let session = session_with(&source);
324 checker
325 .filter_authorized_in_session_by_resource(
326 &session,
327 &subject,
328 &View,
329 sample.clone(),
330 &(),
331 |post| post,
332 )
333 .await
334 .len()
335 })
336 .await;
337
338 assert_eq!(naive.output, bulk.output);
339 println!(
340 "{size},{relationship_checks},{:.3},{:.3},{},x{:.1}",
341 naive.elapsed.as_secs_f64() * 1_000.0,
342 bulk.elapsed.as_secs_f64() * 1_000.0,
343 bulk.output,
344 naive.elapsed.as_secs_f64() / bulk.elapsed.as_secs_f64()
345 );
346 }
347}
348
/// Result of `measure`: the fastest observed wall-clock time and the output
/// produced by that fastest run.
struct Measurement<T> {
    elapsed: Duration,
    output: T,
}

/// Runs `f` three times, one after another, and keeps the fastest run.
///
/// Each run awaits a fresh future from `f`, and timing includes the call to `f`.
/// The returned `Measurement` holds the shortest elapsed time and that run's
/// output; on a tie the earlier run wins. Outputs of slower runs are dropped.
async fn measure<F, Fut, T>(mut f: F) -> Measurement<T>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = T>,
{
    const RUNS: usize = 3;

    let mut fastest: Option<Measurement<T>> = None;
    for _ in 0..RUNS {
        let started = Instant::now();
        let output = f().await;
        let elapsed = started.elapsed();

        let improves = match &fastest {
            Some(current) => elapsed < current.elapsed,
            None => true,
        };
        if improves {
            fastest = Some(Measurement { elapsed, output });
        }
    }

    fastest.expect("measurement should run at least once")
}
376}