1use async_trait::async_trait;
41use gatehouse::{
42 EvalCtx, EvaluationSession, FactKey, FactLoadResult, FactSource, PermissionChecker, Policy,
43 PolicyEvalResult,
44};
45use std::borrow::Cow;
46use std::sync::atomic::{AtomicUsize, Ordering};
47use std::sync::Arc;
48use uuid::Uuid;
49
50#[derive(Debug, Clone)]
58struct Supplier {
59 #[allow(dead_code)]
60 user_id: Uuid,
61 org_id: Uuid,
62}
63
64#[derive(Debug, Clone)]
65struct Invoice {
66 #[allow(dead_code)]
67 id: Uuid,
68 customer_id: Uuid,
69}
70
71#[derive(Debug, Clone)]
72struct ViewAction;
73
74struct HierarchyService {
80 routes: std::collections::HashMap<Uuid, Uuid>,
82 call_count: AtomicUsize,
85}
86
87impl HierarchyService {
88 fn new(routes: std::collections::HashMap<Uuid, Uuid>) -> Self {
89 Self {
90 routes,
91 call_count: AtomicUsize::new(0),
92 }
93 }
94
95 async fn resolve_customer(&self, org_id: Uuid) -> Option<Uuid> {
96 self.call_count.fetch_add(1, Ordering::SeqCst);
97 tokio::task::yield_now().await;
99 self.routes.get(&org_id).copied()
100 }
101
102 fn calls(&self) -> usize {
103 self.call_count.load(Ordering::SeqCst)
104 }
105
106 fn reset(&self) {
107 self.call_count.store(0, Ordering::SeqCst);
108 }
109}
110
111struct WrongSupplierPolicy {
118 hierarchy: Arc<HierarchyService>,
119}
120
121#[async_trait]
122impl Policy<Supplier, Invoice, ViewAction, ()> for WrongSupplierPolicy {
123 async fn evaluate(
124 &self,
125 ctx: &EvalCtx<'_, Supplier, Invoice, ViewAction, ()>,
126 ) -> PolicyEvalResult {
127 let resolved = self.hierarchy.resolve_customer(ctx.subject.org_id).await;
130 match resolved {
131 Some(customer_id) if customer_id == ctx.resource.customer_id => {
132 ctx.grant("subject's supplier org bills under the invoice's customer")
133 }
134 _ => ctx.deny("subject's supplier org does not bill under the invoice's customer"),
135 }
136 }
137 fn policy_type(&self) -> Cow<'static, str> {
138 Cow::Borrowed("WrongSupplierPolicy")
139 }
140}
141
142#[derive(Debug, Clone, Hash, PartialEq, Eq)]
149struct CustomerForOrg(Uuid);
150
151impl FactKey for CustomerForOrg {
152 const NAME: &'static str = "customer_for_org";
153 type Value = Option<Uuid>;
154}
155
156struct CustomerForOrgSource {
161 hierarchy: Arc<HierarchyService>,
162 load_many_calls: Arc<AtomicUsize>,
166}
167
168#[async_trait]
169impl FactSource<CustomerForOrg> for CustomerForOrgSource {
170 async fn load_many(&self, keys: &[CustomerForOrg]) -> Vec<FactLoadResult<Option<Uuid>>> {
171 self.load_many_calls.fetch_add(1, Ordering::SeqCst);
172 let mut out = Vec::with_capacity(keys.len());
176 for CustomerForOrg(org_id) in keys {
177 out.push(FactLoadResult::Found(
178 self.hierarchy.resolve_customer(*org_id).await,
179 ));
180 }
181 out
182 }
183}
184
185struct RightSupplierPolicy;
186
187#[async_trait]
188impl Policy<Supplier, Invoice, ViewAction, ()> for RightSupplierPolicy {
189 async fn evaluate(
190 &self,
191 ctx: &EvalCtx<'_, Supplier, Invoice, ViewAction, ()>,
192 ) -> PolicyEvalResult {
193 match ctx.session.get(CustomerForOrg(ctx.subject.org_id)).await {
198 FactLoadResult::Found(Some(customer_id)) if customer_id == ctx.resource.customer_id => {
199 ctx.grant("subject's supplier org bills under the invoice's customer")
200 }
201 _ => ctx.deny("subject's supplier org does not bill under the invoice's customer"),
202 }
203 }
204 fn policy_type(&self) -> Cow<'static, str> {
205 Cow::Borrowed("RightSupplierPolicy")
206 }
207}
208
209#[tokio::main]
212async fn main() {
213 let supplier_org = Uuid::new_v4();
215 let customer = Uuid::new_v4();
216 let supplier = Supplier {
217 user_id: Uuid::new_v4(),
218 org_id: supplier_org,
219 };
220 let routes = std::collections::HashMap::from([(supplier_org, customer)]);
221 let hierarchy = Arc::new(HierarchyService::new(routes));
222
223 let invoices: Vec<Invoice> = (0..25)
224 .map(|_| Invoice {
225 id: Uuid::new_v4(),
226 customer_id: customer,
227 })
228 .collect();
229
230 let mut wrong_checker = PermissionChecker::<Supplier, Invoice, ViewAction, ()>::new();
232 wrong_checker.add_policy(WrongSupplierPolicy {
233 hierarchy: Arc::clone(&hierarchy),
234 });
235
236 hierarchy.reset();
237 let session = EvaluationSession::empty();
238 let visible = wrong_checker
239 .filter_authorized_in_session_by_resource(
240 &session,
241 &supplier,
242 &ViewAction,
243 invoices.clone(),
244 &(),
245 |i| i,
246 )
247 .await;
248 let wrong_calls = hierarchy.calls();
249 println!(
250 "[wrong] {} invoices -> {} hierarchy lookups (N+1, redundant)",
251 visible.len(),
252 wrong_calls,
253 );
254 assert_eq!(
258 wrong_calls, 25,
259 "the wrong shape pays one hierarchy call per item",
260 );
261 assert_eq!(visible.len(), 25);
262
263 let mut right_checker = PermissionChecker::<Supplier, Invoice, ViewAction, ()>::new();
265 right_checker.add_policy(RightSupplierPolicy);
266
267 hierarchy.reset();
268 let load_many_calls = Arc::new(AtomicUsize::new(0));
269 let session = EvaluationSession::builder()
270 .with_arc::<CustomerForOrg>(Arc::new(CustomerForOrgSource {
271 hierarchy: Arc::clone(&hierarchy),
272 load_many_calls: Arc::clone(&load_many_calls),
273 }))
274 .build();
275 let visible = right_checker
276 .filter_authorized_in_session_by_resource(
277 &session,
278 &supplier,
279 &ViewAction,
280 invoices,
281 &(),
282 |i| i,
283 )
284 .await;
285 let right_calls = hierarchy.calls();
286 let batch_calls = load_many_calls.load(Ordering::SeqCst);
287 println!(
288 "[right] {} invoices -> {} hierarchy lookup ({} batched load_many call, deduped through the session)",
289 visible.len(),
290 right_calls,
291 batch_calls,
292 );
293 assert_eq!(
294 right_calls, 1,
295 "the session deduplicates: one supplier_org, one backend call",
296 );
297 assert_eq!(
298 batch_calls, 1,
299 "the session batches: one load_many call covering the unique key set",
300 );
301 assert_eq!(visible.len(), 25);
302}