1use async_trait::async_trait;
18use axum::{
19 extract::{FromRequestParts, Path, State},
20 http::{request::Parts, HeaderMap, StatusCode},
21 response::IntoResponse,
22 routing::{get, post},
23 Json, Router,
24};
25use gatehouse::*;
26use serde::Serialize;
27use std::collections::HashSet;
28use std::fmt;
29use std::sync::Arc;
30use std::time::{Duration, SystemTime};
31use uuid::Uuid;
32
33#[derive(Debug, Clone)]
38pub struct User {
39 pub id: Uuid,
40 pub roles: Vec<String>,
41}
42
43#[derive(Debug, Clone)]
44pub struct AuthenticatedUser(pub User);
45
46impl<S> FromRequestParts<S> for AuthenticatedUser
47where
48 S: Send + Sync,
49{
50 type Rejection = (StatusCode, String);
51
52 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
53 let id = parts
54 .headers
55 .get("x-user-id")
56 .and_then(|value| value.to_str().ok())
57 .and_then(|raw| Uuid::parse_str(raw).ok())
58 .unwrap_or_else(Uuid::nil);
59
60 let roles = parts
61 .headers
62 .get("x-roles")
63 .and_then(|value| value.to_str().ok())
64 .map(|raw| {
65 raw.split(',')
66 .map(|role| role.trim().to_ascii_lowercase())
67 .filter(|role| !role.is_empty())
68 .collect::<Vec<_>>()
69 })
70 .unwrap_or_else(|| vec!["viewer".to_string()]);
71
72 Ok(AuthenticatedUser(User { id, roles }))
73 }
74}
75
/// Interprets a loosely formatted boolean flag.
///
/// Surrounding whitespace is ignored and matching is ASCII case-insensitive.
/// `true`/`1`/`yes` map to `Some(true)` and `false`/`0`/`no` map to
/// `Some(false)`. Any other input yields `None`.
fn parse_bool(value: &str) -> Option<bool> {
    const TRUTHY: [&str; 3] = ["true", "1", "yes"];
    const FALSY: [&str; 3] = ["false", "0", "no"];

    let candidate = value.trim();
    if TRUTHY.iter().any(|word| candidate.eq_ignore_ascii_case(word)) {
        Some(true)
    } else if FALSY.iter().any(|word| candidate.eq_ignore_ascii_case(word)) {
        Some(false)
    } else {
        None
    }
}
83
84#[derive(Debug, Default, Clone)]
87pub struct InvoiceOverrides {
88 locked: Option<bool>,
89 age_days: Option<u64>,
90}
91
92impl InvoiceOverrides {
93 pub fn from_headers(headers: &HeaderMap) -> Self {
94 let locked = headers
95 .get("x-invoice-locked")
96 .and_then(|value| value.to_str().ok())
97 .and_then(parse_bool);
98
99 let age_days = headers
100 .get("x-invoice-age-days")
101 .and_then(|value| value.to_str().ok())
102 .and_then(|raw| raw.parse::<u64>().ok());
103
104 Self { locked, age_days }
105 }
106
107 fn build_invoice(&self, invoice_id: Uuid) -> Invoice {
108 Invoice {
109 id: invoice_id,
110 owner_id: demo_owner_id(),
111 locked: self.locked.unwrap_or(false),
112 created_at: SystemTime::now()
113 - Duration::from_secs(self.age_days.unwrap_or(10) * 24 * 60 * 60),
114 }
115 }
116}
117
118impl<S> FromRequestParts<S> for InvoiceOverrides
119where
120 S: Send + Sync,
121{
122 type Rejection = (StatusCode, String);
123
124 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
125 Ok(Self::from_headers(&parts.headers))
126 }
127}
128
/// Operation a user is attempting on an invoice.
///
/// A field-less enum, so it derives the same cheap value traits as
/// `Relation` (`Copy`, `Eq`, `Hash`). That allows comparison with `==` and use
/// as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Action {
    /// Modify the invoice; granted by the admin override or the editing policy.
    Edit,
    /// Read the invoice; granted by the admin override or the viewer policy.
    View,
}
135
136#[derive(Debug, Clone)]
139pub struct Invoice {
140 pub id: Uuid,
141 pub owner_id: Uuid,
142 pub locked: bool,
143 pub created_at: SystemTime,
144}
145
/// Per-request environment handed to every policy evaluation.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Wall-clock instant the request is evaluated at; the editing policy
    /// measures invoice age against it.
    pub current_time: SystemTime,
}

impl RequestContext {
    /// Captures the current wall-clock time as the evaluation instant.
    fn now() -> Self {
        let current_time = SystemTime::now();
        RequestContext { current_time }
    }
}
160
/// Relations that can be recorded between a user (subject) and an invoice
/// (resource) in the relationship store.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Relation {
    /// Required, together with a `View` action, by the invoice viewer policy.
    Viewer,
}

impl fmt::Display for Relation {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Relation::Viewer => "viewer",
        };
        f.write_str(label)
    }
}
173
174type InvoiceRelationship = RelationshipQuery<Uuid, Uuid, Relation>;
175
176#[derive(Clone)]
177pub struct InMemoryRelationshipSource {
178 grants: Arc<HashSet<InvoiceRelationship>>,
179}
180
181impl InMemoryRelationshipSource {
182 fn new(grants: impl IntoIterator<Item = InvoiceRelationship>) -> Self {
183 Self {
184 grants: Arc::new(grants.into_iter().collect()),
185 }
186 }
187}
188
189#[async_trait]
190impl FactSource<InvoiceRelationship> for InMemoryRelationshipSource {
191 async fn load_many(&self, keys: &[InvoiceRelationship]) -> Vec<FactLoadResult<bool>> {
192 keys.iter()
193 .map(|key| FactLoadResult::Found(self.grants.contains(key)))
194 .collect()
195 }
196}
197
198#[derive(Debug, Clone, Serialize)]
199pub struct InvoiceSummary {
200 pub id: Uuid,
201 pub owner_id: Uuid,
202 pub locked: bool,
203}
204
205impl From<Invoice> for InvoiceSummary {
206 fn from(invoice: Invoice) -> Self {
207 Self {
208 id: invoice.id,
209 owner_id: invoice.owner_id,
210 locked: invoice.locked,
211 }
212 }
213}
214
215#[derive(Clone)]
223pub struct AppState {
224 checker: PermissionChecker<User, Invoice, Action, RequestContext>,
225 invoice_relationships: Arc<dyn FactSource<InvoiceRelationship>>,
226 invoices: Arc<Vec<Invoice>>,
227}
228
229impl AppState {
230 pub fn demo() -> Self {
231 let viewer_id = demo_viewer_id();
232 let invoices = Arc::new(demo_invoices());
233 let grants = invoices
236 .iter()
237 .filter(|invoice| invoice.owner_id != demo_owner_id())
238 .map(|invoice| InvoiceRelationship {
239 subject_id: viewer_id,
240 resource_id: invoice.id,
241 relation: Relation::Viewer,
242 });
243
244 Self {
245 checker: build_permission_checker(),
246 invoice_relationships: Arc::new(InMemoryRelationshipSource::new(grants)),
247 invoices,
248 }
249 }
250
251 fn request_session(&self) -> EvaluationSession {
252 EvaluationSession::builder()
253 .with_arc::<InvoiceRelationship>(Arc::clone(&self.invoice_relationships))
254 .build()
255 }
256}
257
258fn demo_owner_id() -> Uuid {
259 Uuid::parse_str("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa").unwrap()
260}
261
262fn demo_viewer_id() -> Uuid {
263 Uuid::parse_str("eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee").unwrap()
264}
265
266fn demo_invoices() -> Vec<Invoice> {
267 vec![
268 Invoice {
269 id: Uuid::parse_str("11111111-1111-1111-1111-111111111111").unwrap(),
270 owner_id: demo_owner_id(),
271 locked: false,
272 created_at: SystemTime::now() - Duration::from_secs(10 * 24 * 60 * 60),
273 },
274 Invoice {
275 id: Uuid::parse_str("22222222-2222-2222-2222-222222222222").unwrap(),
276 owner_id: Uuid::parse_str("cccccccc-cccc-cccc-cccc-cccccccccccc").unwrap(),
277 locked: false,
278 created_at: SystemTime::now() - Duration::from_secs(5 * 24 * 60 * 60),
279 },
280 Invoice {
281 id: Uuid::parse_str("33333333-3333-3333-3333-333333333333").unwrap(),
282 owner_id: Uuid::parse_str("dddddddd-dddd-dddd-dddd-dddddddddddd").unwrap(),
283 locked: true,
284 created_at: SystemTime::now() - Duration::from_secs(2 * 24 * 60 * 60),
285 },
286 ]
287}
288
289fn admin_override_policy() -> Box<dyn Policy<User, Invoice, Action, RequestContext>> {
296 PolicyBuilder::<User, Invoice, Action, RequestContext>::new("AdminOverridePolicy")
297 .when(|user, _action, _invoice, _ctx| user.roles.contains(&"admin".to_string()))
298 .build()
299}
300
301fn invoice_viewer_policy() -> Box<dyn Policy<User, Invoice, Action, RequestContext>> {
305 let is_view: Arc<dyn Policy<User, Invoice, Action, RequestContext>> = Arc::from(
306 PolicyBuilder::<User, Invoice, Action, RequestContext>::new("IsView")
307 .when(|_user, action, _invoice, _ctx| matches!(action, Action::View))
308 .build(),
309 );
310 let viewer_relationship: Arc<dyn Policy<User, Invoice, Action, RequestContext>> =
311 Arc::new(RebacPolicy::new(
312 |user: &User| user.id,
313 |invoice: &Invoice| invoice.id,
314 Relation::Viewer,
315 ));
316
317 Box::new(
318 AndPolicy::try_new(vec![is_view, viewer_relationship])
319 .expect("invoice viewer policy has the guard and relationship checks"),
320 )
321}
322
323fn invoice_editing_policy() -> Box<dyn Policy<User, Invoice, Action, RequestContext>> {
327 let is_edit = PolicyBuilder::<User, Invoice, Action, RequestContext>::new("IsEdit")
328 .when(|_user, action, _invoice, _ctx| matches!(action, Action::Edit))
329 .build();
330
331 let is_owner = PolicyBuilder::<User, Invoice, Action, RequestContext>::new("IsOwnerOfInvoice")
332 .when(|user, _action, invoice, _ctx| user.id == invoice.owner_id)
333 .build();
334
335 let invoice_not_locked =
336 PolicyBuilder::<User, Invoice, Action, RequestContext>::new("InvoiceNotLocked")
337 .when(|_user, _action, invoice, _ctx| !invoice.locked)
338 .build();
339
340 const THIRTY_DAYS: u64 = 30 * 24 * 60 * 60;
341 let invoice_age_under_30_days =
342 PolicyBuilder::<User, Invoice, Action, RequestContext>::new("InvoiceAgeUnder30Days")
343 .when(move |_user, _action, invoice, ctx| {
344 ctx.current_time
345 .duration_since(invoice.created_at)
346 .unwrap_or_default()
347 .as_secs()
348 <= THIRTY_DAYS
349 })
350 .build();
351
352 Box::new(
353 AndPolicy::try_new(vec![
354 Arc::from(is_edit),
355 Arc::from(is_owner),
356 Arc::from(invoice_not_locked),
357 Arc::from(invoice_age_under_30_days),
358 ])
359 .expect("invoice editing policy has at least one rule"),
360 )
361}
362
363pub fn build_permission_checker() -> PermissionChecker<User, Invoice, Action, RequestContext> {
366 let mut checker = PermissionChecker::named("InvoiceChecker");
367 checker.add_policy(admin_override_policy());
368 checker.add_policy(invoice_viewer_policy());
369 checker.add_policy(invoice_editing_policy());
370 checker
371}
372
373pub async fn view_invoice_handler(
378 Path(invoice_id): Path<Uuid>,
379 State(state): State<AppState>,
380 AuthenticatedUser(user): AuthenticatedUser,
381 overrides: InvoiceOverrides,
382) -> impl IntoResponse {
383 let invoice = overrides.build_invoice(invoice_id);
385 let session = state.request_session();
386
387 if state
388 .checker
389 .evaluate_in_session(
390 &session,
391 &user,
392 &Action::View,
393 &invoice,
394 &RequestContext::now(),
395 )
396 .await
397 .is_granted()
398 {
399 (StatusCode::OK, format!("{invoice:?}")).into_response()
400 } else {
401 (
402 StatusCode::FORBIDDEN,
403 "You are not authorized to view this invoice",
404 )
405 .into_response()
406 }
407}
408
409pub async fn list_invoices_handler(
410 State(state): State<AppState>,
411 AuthenticatedUser(user): AuthenticatedUser,
412) -> impl IntoResponse {
413 let session = state.request_session();
414 let candidates = state.invoices.as_ref().clone();
415
416 let visible = state
420 .checker
421 .filter_authorized_in_session_by_resource(
422 &session,
423 &user,
424 &Action::View,
425 candidates,
426 &RequestContext::now(),
427 |invoice| invoice,
428 )
429 .await
430 .into_iter()
431 .map(InvoiceSummary::from)
432 .collect::<Vec<_>>();
433
434 Json(visible).into_response()
435}
436
437pub async fn edit_invoice_handler(
438 Path(invoice_id): Path<Uuid>,
439 State(state): State<AppState>,
440 AuthenticatedUser(user): AuthenticatedUser,
441 overrides: InvoiceOverrides,
442) -> impl IntoResponse {
443 let invoice = overrides.build_invoice(invoice_id);
444 let session = state.request_session();
445
446 if state
447 .checker
448 .evaluate_in_session(
449 &session,
450 &user,
451 &Action::Edit,
452 &invoice,
453 &RequestContext::now(),
454 )
455 .await
456 .is_granted()
457 {
458 (StatusCode::OK, "Invoice edited successfully").into_response()
459 } else {
460 (
461 StatusCode::FORBIDDEN,
462 "You are not authorized to edit this invoice",
463 )
464 .into_response()
465 }
466}
467
468#[tokio::main]
473async fn main() {
474 let state = AppState::demo();
477
478 let app = Router::new()
479 .route("/invoices", get(list_invoices_handler))
480 .route("/invoices/{invoice_id}", get(view_invoice_handler))
481 .route("/invoices/{invoice_id}/edit", post(edit_invoice_handler))
482 .with_state(state);
483
484 let listener = tokio::net::TcpListener::bind("0.0.0.0:8000").await.unwrap();
485 println!("Listening on http://0.0.0.0:8000");
486 axum::serve(listener, app).await.unwrap();
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use gatehouse::AccessEvaluation;
    use std::time::{Duration, SystemTime};

    fn make_invoice(owner_id: Uuid, locked: bool, age_in_days: u64) -> Invoice {
        Invoice {
            id: Uuid::new_v4(),
            owner_id,
            locked,
            created_at: SystemTime::now() - Duration::from_secs(age_in_days * 24 * 60 * 60),
        }
    }

    fn context_now() -> RequestContext {
        RequestContext {
            current_time: SystemTime::now(),
        }
    }

    /// The identity that `AppState::demo` grants `Relation::Viewer` to.
    fn demo_viewer() -> User {
        User {
            id: demo_viewer_id(),
            roles: vec!["viewer".to_string()],
        }
    }

    #[tokio::test]
    async fn admin_override_allows_anything() {
        let checker = build_permission_checker();
        let admin = User {
            id: Uuid::new_v4(),
            roles: vec!["admin".to_string()],
        };

        let invoice = make_invoice(Uuid::new_v4(), true, 60);

        let result = checker
            .check(&admin, &Action::Edit, &invoice, &context_now())
            .await;

        assert!(result.is_granted(), "admin override should allow anything");
        match result {
            AccessEvaluation::Granted { policy_type, .. } => {
                assert_eq!(&policy_type, "AdminOverridePolicy");
            }
            _ => panic!("expected admin override to grant"),
        }
    }

    #[tokio::test]
    async fn owner_can_edit_unlocked_recent_invoice() {
        let checker = build_permission_checker();
        let owner_id = Uuid::new_v4();
        let user = User {
            id: owner_id,
            roles: vec!["user".to_string()],
        };

        let invoice = make_invoice(owner_id, false, 10);

        let result = checker
            .check(&user, &Action::Edit, &invoice, &context_now())
            .await;

        assert!(
            result.is_granted(),
            "owner should edit an unlocked invoice under 30 days old"
        );
    }

    #[tokio::test]
    async fn locked_invoice_cannot_be_edited() {
        let checker = build_permission_checker();
        let owner_id = Uuid::new_v4();
        let user = User {
            id: owner_id,
            roles: vec!["user".to_string()],
        };

        let invoice = make_invoice(owner_id, true, 10);

        let result = checker
            .check(&user, &Action::Edit, &invoice, &context_now())
            .await;

        assert!(!result.is_granted(), "a locked invoice should be denied");

        if let AccessEvaluation::Denied { trace, .. } = result {
            let trace_str = trace.format();
            assert!(
                trace_str.contains("InvoiceNotLocked"),
                "expected InvoiceNotLocked to fail in trace:\n{trace_str}"
            );
        }
    }

    #[tokio::test]
    async fn non_owner_cannot_edit() {
        let checker = build_permission_checker();
        let user = User {
            id: Uuid::new_v4(),
            roles: vec!["user".to_string()],
        };

        let invoice = make_invoice(Uuid::new_v4(), false, 10);

        let result = checker
            .check(&user, &Action::Edit, &invoice, &context_now())
            .await;

        assert!(!result.is_granted(), "a non-owner should be denied");
        if let AccessEvaluation::Denied { trace, .. } = result {
            assert!(
                trace.format().contains("IsOwnerOfInvoice"),
                "expected IsOwnerOfInvoice to fail in trace"
            );
        }
    }

    #[tokio::test]
    async fn stale_invoice_cannot_be_edited() {
        let checker = build_permission_checker();
        let owner_id = Uuid::new_v4();
        let user = User {
            id: owner_id,
            roles: vec!["user".to_string()],
        };

        let invoice = make_invoice(owner_id, false, 31);

        let result = checker
            .check(&user, &Action::Edit, &invoice, &context_now())
            .await;
        assert!(
            !result.is_granted(),
            "an invoice older than 30 days should be denied"
        );
    }

    // The viewer policy depends on relationship facts, so these tests evaluate
    // inside the same kind of session the handlers use.

    #[tokio::test]
    async fn viewer_relationship_grants_view_in_session() {
        let state = AppState::demo();
        let session = state.request_session();
        let shared = state
            .invoices
            .iter()
            .find(|invoice| invoice.owner_id != demo_owner_id())
            .expect("demo data has an invoice shared with the viewer");

        let result = state
            .checker
            .evaluate_in_session(
                &session,
                &demo_viewer(),
                &Action::View,
                shared,
                &context_now(),
            )
            .await;

        assert!(
            result.is_granted(),
            "a Viewer relationship should allow viewing"
        );
    }

    #[tokio::test]
    async fn missing_viewer_relationship_denies_view() {
        let state = AppState::demo();
        let session = state.request_session();
        let unshared = state
            .invoices
            .iter()
            .find(|invoice| invoice.owner_id == demo_owner_id())
            .expect("demo data has an invoice owned by the demo owner");

        let result = state
            .checker
            .evaluate_in_session(
                &session,
                &demo_viewer(),
                &Action::View,
                unshared,
                &context_now(),
            )
            .await;

        assert!(
            !result.is_granted(),
            "viewing without a Viewer relationship should be denied"
        );
    }

    #[tokio::test]
    async fn viewer_relationship_does_not_grant_edit() {
        let state = AppState::demo();
        let session = state.request_session();
        let shared = state
            .invoices
            .iter()
            .find(|invoice| invoice.owner_id != demo_owner_id())
            .expect("demo data has an invoice shared with the viewer");

        let result = state
            .checker
            .evaluate_in_session(
                &session,
                &demo_viewer(),
                &Action::Edit,
                shared,
                &context_now(),
            )
            .await;

        assert!(
            !result.is_granted(),
            "a Viewer relationship must not allow editing"
        );
    }
}
633
#[cfg(test)]
mod integration_tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        Router,
    };
    use tower::ServiceExt;

    /// Router exposing only the edit endpoint, backed by demo state.
    fn test_app() -> Router {
        Router::new()
            .route("/invoices/{invoice_id}/edit", post(edit_invoice_handler))
            .with_state(AppState::demo())
    }

    /// Sends an empty-bodied `POST` to `uri` with `headers` through a fresh
    /// test app and returns the response status.
    async fn post_status(uri: &str, headers: &[(&str, &str)]) -> StatusCode {
        let request = headers
            .iter()
            .fold(
                Request::builder().method("POST").uri(uri),
                |builder, (name, value)| builder.header(*name, *value),
            )
            .body(Body::empty())
            .unwrap();

        test_app().oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn edit_invoice_handler_allows_admin() {
        let status = post_status(
            "/invoices/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/edit",
            &[("x-roles", "admin")],
        )
        .await;

        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn edit_invoice_handler_denies_regular_user_if_locked() {
        let status = post_status(
            "/invoices/cccccccc-cccc-cccc-cccc-cccccccccccc/edit",
            &[
                ("x-user-id", "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"),
                ("x-roles", "author"),
                ("x-invoice-locked", "true"),
            ],
        )
        .await;

        assert_eq!(status, StatusCode::FORBIDDEN);
    }
}