1use async_trait::async_trait;
19use gatehouse::{
20 EvalCtx, EvaluationSession, LookupPage, LookupSource, PermissionChecker, Policy,
21 PolicyEvalResult,
22};
23use std::collections::HashMap;
24use std::fmt;
25use std::num::NonZeroUsize;
26use std::sync::Arc;
27use uuid::Uuid;
28
29#[derive(Clone, Debug)]
32struct User {
33 id: Uuid,
34 is_admin: bool,
35}
36
37#[derive(Clone, Debug)]
38struct Document {
39 id: Uuid,
40 title: String,
41}
42
43#[derive(Clone, Debug)]
44struct View;
45
46struct AdminPolicy;
50
51#[async_trait]
52impl Policy<User, Document, View, ()> for AdminPolicy {
53 async fn evaluate(&self, ctx: &EvalCtx<'_, User, Document, View, ()>) -> PolicyEvalResult {
54 if ctx.subject.is_admin {
55 ctx.grant("admin override")
56 } else {
57 ctx.deny("not admin")
58 }
59 }
60 fn policy_type(&self) -> std::borrow::Cow<'static, str> {
61 std::borrow::Cow::Borrowed("AdminPolicy")
62 }
63}
64
65struct ViewerPolicy {
68 viewers: HashMap<Uuid, Vec<Uuid>>, }
70
71#[async_trait]
72impl Policy<User, Document, View, ()> for ViewerPolicy {
73 async fn evaluate(&self, ctx: &EvalCtx<'_, User, Document, View, ()>) -> PolicyEvalResult {
74 let granted = self
75 .viewers
76 .get(&ctx.resource.id)
77 .map(|users| users.contains(&ctx.subject.id))
78 .unwrap_or(false);
79 if granted {
80 ctx.grant("viewer relation")
81 } else {
82 ctx.deny("no viewer relation")
83 }
84 }
85 fn policy_type(&self) -> std::borrow::Cow<'static, str> {
86 std::borrow::Cow::Borrowed("ViewerPolicy")
87 }
88}
89
90struct InMemoryViewerLookup {
99 per_user: HashMap<Uuid, Vec<Uuid>>,
100}
101
/// Error produced by `InMemoryViewerLookup` when a pagination cursor cannot be decoded.
/// Wraps a plain message that is rendered verbatim by `Display`.
#[derive(Debug)]
struct ViewerLookupError(String);

impl fmt::Display for ViewerLookupError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(message) = self;
        write!(f, "{message}")
    }
}

impl std::error::Error for ViewerLookupError {}
110
111#[async_trait]
112impl LookupSource for InMemoryViewerLookup {
113 type Subject = User;
114 type Id = Uuid;
115 type Error = ViewerLookupError;
116
117 async fn lookup_page(
118 &self,
119 subject: &User,
120 cursor: Option<&[u8]>,
121 limit: NonZeroUsize,
122 ) -> Result<LookupPage<Uuid>, ViewerLookupError> {
123 let offset = cursor
124 .map(|c| {
125 std::str::from_utf8(c)
126 .map_err(|_| ViewerLookupError("non-utf8 cursor".into()))
127 .and_then(|s| {
128 s.parse::<usize>()
129 .map_err(|_| ViewerLookupError("cursor not a number".into()))
130 })
131 })
132 .transpose()?
133 .unwrap_or(0);
134
135 let all = self.per_user.get(&subject.id).cloned().unwrap_or_default();
136
137 if offset >= all.len() {
138 return Ok(LookupPage {
139 ids: Vec::new(),
140 next_cursor: None,
141 });
142 }
143 let end = (offset + limit.get()).min(all.len());
144 let next_cursor = (end < all.len()).then(|| end.to_string().into_bytes());
145 Ok(LookupPage {
146 ids: all[offset..end].to_vec(),
147 next_cursor,
148 })
149 }
150}
151
152#[tokio::main]
155async fn main() {
156 let alice = User {
158 id: Uuid::new_v4(),
159 is_admin: false,
160 };
161 let admin = User {
162 id: Uuid::new_v4(),
163 is_admin: true,
164 };
165 let docs: Vec<Document> = (0..7)
166 .map(|i| Document {
167 id: Uuid::new_v4(),
168 title: format!("doc-{i}"),
169 })
170 .collect();
171
172 let viewer_doc_ids: Vec<Uuid> = [&docs[1], &docs[3], &docs[5]]
174 .into_iter()
175 .map(|d| d.id)
176 .collect();
177
178 let viewers: HashMap<Uuid, Vec<Uuid>> = viewer_doc_ids
179 .iter()
180 .map(|doc_id| (*doc_id, vec![alice.id]))
181 .collect();
182
183 let viewer_lookup_index: HashMap<Uuid, Vec<Uuid>> =
184 HashMap::from([(alice.id, viewer_doc_ids.clone())]);
185
186 let catalog: Arc<HashMap<Uuid, Document>> =
189 Arc::new(docs.iter().map(|d| (d.id, d.clone())).collect());
190
191 let lookup = InMemoryViewerLookup {
192 per_user: viewer_lookup_index,
193 };
194
195 let hydrator = {
199 let catalog = Arc::clone(&catalog);
200 move |ids: &[Uuid]| {
201 let catalog = Arc::clone(&catalog);
202 let ids = ids.to_vec();
203 async move {
204 Ok::<_, std::convert::Infallible>(
205 ids.iter().map(|id| catalog.get(id).cloned()).collect(),
206 )
207 }
208 }
209 };
210
211 let mut checker = PermissionChecker::<User, Document, View, ()>::new();
215 checker.add_policy(AdminPolicy);
216 checker.add_policy(ViewerPolicy { viewers });
217
218 let session = EvaluationSession::empty();
219 let page_size = NonZeroUsize::new(2).unwrap();
220
221 let alice_visible = checker
223 .lookup_authorized(&session, &alice, &View, &(), &lookup, page_size, &hydrator)
224 .await
225 .expect("lookup ok");
226 println!("Alice sees {} document(s):", alice_visible.len());
227 for doc in &alice_visible {
228 println!(" - {} ({})", doc.title, doc.id);
229 }
230 let alice_visible_ids: Vec<Uuid> = alice_visible.iter().map(|doc| doc.id).collect();
231 assert_eq!(
232 alice_visible_ids, viewer_doc_ids,
233 "the lookup + policy stack should authorize exactly the viewer-granted documents, in source order"
234 );
235
236 let admin_via_lookup = checker
244 .lookup_authorized(&session, &admin, &View, &(), &lookup, page_size, &hydrator)
245 .await
246 .expect("lookup ok");
247 println!(
248 "\nAdmin via the viewer-lookup sees {} document(s) — this is bounded \
249 by what the source enumerates; admin grants still apply at point checks.",
250 admin_via_lookup.len()
251 );
252 assert!(
253 admin_via_lookup.is_empty(),
254 "the viewer lookup enumerates nothing for the admin, so the listing is empty"
255 );
256
257 let any_doc = &docs[0];
260 let admin_point = checker
261 .evaluate_in_session(&session, &admin, &View, any_doc, &())
262 .await;
263 println!(
264 "\nAdmin point check on '{}': {}",
265 any_doc.title,
266 if admin_point.is_granted() {
267 "Granted"
268 } else {
269 "Denied"
270 }
271 );
272 admin_point.assert_granted_by("AdminPolicy");
273
274 println!("\nStreaming Alice's visible documents page-by-page:");
278 let mut cursor: Option<Vec<u8>> = None;
279 let mut page_index = 0;
280 let mut streamed_total = 0;
281 loop {
282 let page = checker
283 .lookup_authorized_page(
284 &session,
285 &alice,
286 &View,
287 &(),
288 &lookup,
289 cursor.as_deref(),
290 page_size,
291 &hydrator,
292 )
293 .await
294 .expect("lookup_authorized_page ok");
295 println!(" page {page_index}: {} authorized", page.resources.len());
296 page_index += 1;
297 streamed_total += page.resources.len();
298 match page.next_cursor {
299 None => break,
300 Some(next) => cursor = Some(next),
301 }
302 }
303 assert_eq!(page_index, 2);
306 assert_eq!(streamed_total, viewer_doc_ids.len());
307}