1use std::any::{Any, TypeId};
55use std::collections::HashMap;
56use std::pin::Pin;
57use std::sync::{Arc, RwLock};
58
59use http::StatusCode;
60
61use crate::session::Session;
62
/// A pinned, heap-allocated, `Send` future that may borrow data for `'a` and
/// resolves to `T`.
///
/// The policy and scope hooks in this module return this type instead of
/// being `async fn`s, which keeps those traits usable as trait objects
/// (`Arc<dyn Policy<R>>` / `Arc<dyn Scope<R>>`).
pub type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
67
68#[derive(Clone)]
79pub struct PolicyContext {
80 pub session: Session,
84
85 pub user_id: Option<String>,
88
89 pub roles: Vec<String>,
92
93 #[cfg(feature = "db")]
97 pub pool:
98 Option<diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>,
99
100 pub policy_registry: PolicyRegistry,
106}
107
108impl PolicyContext {
109 pub async fn from_session(session: &Session, auth_session_key: &str) -> Self {
117 let user_id = session.get(auth_session_key).await;
118 let role = session.get("role").await;
119 let roles = role.into_iter().collect();
120 Self {
121 session: session.clone(),
122 user_id,
123 roles,
124 #[cfg(feature = "db")]
125 pool: None,
126 policy_registry: PolicyRegistry::default(),
127 }
128 }
129
130 pub async fn from_request(state: &crate::AppState, session: &Session) -> Self {
134 let mut ctx = Self::from_session(session, state.auth_session_key()).await;
135 ctx.policy_registry = state.policy_registry().clone();
136 #[cfg(feature = "db")]
137 {
138 if let Some(pool) = state.pool() {
139 ctx.pool = Some(pool.clone());
140 }
141 }
142 ctx
143 }
144
145 #[must_use]
147 pub const fn is_authenticated(&self) -> bool {
148 self.user_id.is_some()
149 }
150
151 #[must_use]
155 pub fn user_id_i64(&self) -> Option<i64> {
156 self.user_id.as_deref().and_then(|s| s.parse().ok())
157 }
158
159 #[must_use]
161 pub fn has_role(&self, role: &str) -> bool {
162 self.roles.iter().any(|r| r == role)
163 }
164
165 #[must_use]
168 pub fn has_any_role<I, S>(&self, candidates: I) -> bool
169 where
170 I: IntoIterator<Item = S>,
171 S: AsRef<str>,
172 {
173 candidates.into_iter().any(|c| self.has_role(c.as_ref()))
174 }
175
176 #[cfg(feature = "db")]
180 #[must_use]
181 pub fn with_pool(
182 mut self,
183 pool: diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>,
184 ) -> Self {
185 self.pool = Some(pool);
186 self
187 }
188}
189
190pub trait Policy<R: Send + Sync + 'static>: Send + Sync + 'static {
213 fn can_show<'a>(&'a self, _ctx: &'a PolicyContext, _resource: &'a R) -> BoxFuture<'a, bool> {
215 Box::pin(async { false })
216 }
217
218 fn can_create<'a>(&'a self, _ctx: &'a PolicyContext) -> BoxFuture<'a, bool> {
225 Box::pin(async { false })
226 }
227
228 fn can_create_payload<'a>(
232 &'a self,
233 ctx: &'a PolicyContext,
234 _payload: &'a serde_json::Value,
235 ) -> BoxFuture<'a, bool> {
236 self.can_create(ctx)
237 }
238
239 fn can_update<'a>(&'a self, _ctx: &'a PolicyContext, _resource: &'a R) -> BoxFuture<'a, bool> {
241 Box::pin(async { false })
242 }
243
244 fn can_delete<'a>(&'a self, _ctx: &'a PolicyContext, _resource: &'a R) -> BoxFuture<'a, bool> {
246 Box::pin(async { false })
247 }
248
249 fn can<'a>(
254 &'a self,
255 action: &'a str,
256 ctx: &'a PolicyContext,
257 resource: &'a R,
258 ) -> BoxFuture<'a, bool> {
259 Box::pin(async move {
260 match action {
261 "show" | "read" => self.can_show(ctx, resource).await,
262 "create" => self.can_create(ctx).await,
263 "update" | "edit" => self.can_update(ctx, resource).await,
264 "delete" | "destroy" => self.can_delete(ctx, resource).await,
265 _ => false,
266 }
267 })
268 }
269}
270
/// Record-listing rules for resources of type `R` (database-enabled build).
///
/// `list` receives the request's [`PolicyContext`] and a live connection and
/// returns the records this scope exposes. The default returns an empty list,
/// so nothing is visible until a scope is implemented.
#[cfg(feature = "db")]
pub trait Scope<R: Send + Sync + 'static>: Send + Sync + 'static {
    /// Loads the records visible in `_ctx` using `_conn`. Empty by default.
    fn list<'a>(
        &'a self,
        _ctx: &'a PolicyContext,
        _conn: &'a mut diesel_async::AsyncPgConnection,
    ) -> BoxFuture<'a, crate::AutumnResult<Vec<R>>> {
        Box::pin(async { Ok(vec![]) })
    }
}
301
302#[cfg(not(feature = "db"))]
306pub trait Scope<R: Send + Sync + 'static>: Send + Sync + 'static {
307 fn list<'a>(&'a self, _ctx: &'a PolicyContext) -> BoxFuture<'a, crate::AutumnResult<Vec<R>>> {
308 Box::pin(async { Ok(Vec::new()) })
309 }
310}
311
312pub struct ScopeQuery<'a, R: Send + Sync + 'static> {
322 ctx: &'a PolicyContext,
323 _marker: std::marker::PhantomData<fn() -> R>,
324}
325
#[cfg(feature = "db")]
impl<R: Send + Sync + 'static> ScopeQuery<'_, R> {
    /// Runs the [`Scope`] registered for `R` against `conn`.
    ///
    /// # Errors
    ///
    /// Fails with `500 Internal Server Error` when the context's registry has
    /// no scope for `R`. Otherwise returns whatever the scope's `list`
    /// produces.
    pub async fn load(
        self,
        conn: &mut diesel_async::AsyncPgConnection,
    ) -> crate::AutumnResult<Vec<R>> {
        let Some(scope) = self.ctx.policy_registry.scope::<R>() else {
            let detail = format!(
                "no scope registered for resource type {}",
                std::any::type_name::<R>()
            );
            let missing = crate::AutumnError::from(std::io::Error::other(detail))
                .with_status(StatusCode::INTERNAL_SERVER_ERROR);
            return Err(missing);
        };
        scope.list(self.ctx, conn).await
    }
}
352
353#[cfg(not(feature = "db"))]
354impl<R: Send + Sync + 'static> ScopeQuery<'_, R> {
355 pub async fn load(self) -> crate::AutumnResult<Vec<R>> {
356 let scope = self.ctx.policy_registry.scope::<R>().ok_or_else(|| {
357 crate::AutumnError::from(std::io::Error::other(format!(
358 "no scope registered for resource type {}",
359 std::any::type_name::<R>()
360 )))
361 .with_status(StatusCode::INTERNAL_SERVER_ERROR)
362 })?;
363 scope.list(self.ctx).await
364 }
365}
366
367pub trait Scoped: Send + Sync + Sized + 'static {
381 #[must_use]
384 fn scope(ctx: &PolicyContext) -> ScopeQuery<'_, Self> {
385 ScopeQuery {
386 ctx,
387 _marker: std::marker::PhantomData,
388 }
389 }
390}
391
392impl<T: Send + Sync + 'static> Scoped for T {}
393
/// Shared map from a resource type to its registered [`Policy`] and [`Scope`].
///
/// Cloning is cheap: every clone shares the same storage, so a registration
/// made through one handle is visible through all of them.
#[derive(Clone, Default)]
pub struct PolicyRegistry {
    inner: Arc<RwLock<RegistryInner>>,
}

/// Type-erased storage behind [`PolicyRegistry`].
///
/// Keys are `TypeId::of::<R>()`. Values are an `Arc<dyn Policy<R>>` (or
/// `Arc<dyn Scope<R>>`) wrapped once more as `dyn Any`, so a single map can
/// hold entries for every `R`.
#[derive(Default)]
struct RegistryInner {
    policies: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
    scopes: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
412
413impl PolicyRegistry {
414 pub fn register_policy<R, P>(&self, policy: P)
423 where
424 R: Send + Sync + 'static,
425 P: Policy<R>,
426 {
427 let mut inner = self.inner.write().expect("policy registry lock poisoned");
428 let key = TypeId::of::<R>();
429 assert!(
430 !inner.policies.contains_key(&key),
431 "Policy for {} already registered. Multiple policies per resource are not supported.",
432 std::any::type_name::<R>()
433 );
434 let boxed: Arc<dyn Policy<R>> = Arc::new(policy);
435 inner.policies.insert(key, Arc::new(boxed));
436 }
437
438 pub fn register_scope<R, S>(&self, scope: S)
444 where
445 R: Send + Sync + 'static,
446 S: Scope<R>,
447 {
448 let mut inner = self.inner.write().expect("policy registry lock poisoned");
449 let key = TypeId::of::<R>();
450 assert!(
451 !inner.scopes.contains_key(&key),
452 "Scope for {} already registered. Multiple scopes per resource are not supported.",
453 std::any::type_name::<R>()
454 );
455 let boxed: Arc<dyn Scope<R>> = Arc::new(scope);
456 inner.scopes.insert(key, Arc::new(boxed));
457 }
458
459 #[must_use]
466 pub fn policy<R: Send + Sync + 'static>(&self) -> Option<Arc<dyn Policy<R>>> {
467 let inner = self.inner.read().expect("policy registry lock poisoned");
468 inner
469 .policies
470 .get(&TypeId::of::<R>())
471 .and_then(|a| a.downcast_ref::<Arc<dyn Policy<R>>>().cloned())
472 }
473
474 #[must_use]
480 pub fn scope<R: Send + Sync + 'static>(&self) -> Option<Arc<dyn Scope<R>>> {
481 let inner = self.inner.read().expect("policy registry lock poisoned");
482 inner
483 .scopes
484 .get(&TypeId::of::<R>())
485 .and_then(|a| a.downcast_ref::<Arc<dyn Scope<R>>>().cloned())
486 }
487
488 #[must_use]
494 pub fn has_policy<R: Send + Sync + 'static>(&self) -> bool {
495 self.inner
496 .read()
497 .expect("policy registry lock poisoned")
498 .policies
499 .contains_key(&TypeId::of::<R>())
500 }
501}
502
503impl std::fmt::Debug for PolicyRegistry {
504 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
505 let inner = self.inner.read().expect("policy registry lock poisoned");
506 f.debug_struct("PolicyRegistry")
507 .field("policies", &inner.policies.len())
508 .field("scopes", &inner.scopes.len())
509 .finish()
510 }
511}
512
/// Which response a denied authorization check produces.
///
/// Defaults to [`ForbiddenResponse::NotFound404`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ForbiddenResponse {
    /// Deny with `403 Forbidden`.
    Forbidden403,
    /// Deny with `404 Not Found` (the default).
    #[default]
    NotFound404,
}
532
533impl ForbiddenResponse {
534 #[must_use]
536 pub const fn status(self) -> StatusCode {
537 match self {
538 Self::Forbidden403 => StatusCode::FORBIDDEN,
539 Self::NotFound404 => StatusCode::NOT_FOUND,
540 }
541 }
542
543 #[must_use]
547 pub const fn message(self) -> &'static str {
548 match self {
549 Self::Forbidden403 => "forbidden",
550 Self::NotFound404 => "not found",
551 }
552 }
553
554 #[must_use]
558 pub fn into_error(self) -> crate::AutumnError {
559 crate::AutumnError::from(std::io::Error::other(self.message())).with_status(self.status())
560 }
561}
562
563impl std::str::FromStr for ForbiddenResponse {
564 type Err = String;
565
566 fn from_str(s: &str) -> Result<Self, Self::Err> {
567 match s.trim() {
568 "403" | "forbidden" | "Forbidden" => Ok(Self::Forbidden403),
569 "404" | "not_found" | "NotFound" | "" => Ok(Self::NotFound404),
570 other => Err(format!(
571 "invalid forbidden_response: {other:?} (expected \"403\" or \"404\")"
572 )),
573 }
574 }
575}
576
577impl<'de> serde::Deserialize<'de> for ForbiddenResponse {
578 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
579 where
580 D: serde::Deserializer<'de>,
581 {
582 let raw = String::deserialize(deserializer)?;
583 raw.parse().map_err(serde::de::Error::custom)
584 }
585}
586
587pub async fn authorize<R>(
624 state: &crate::AppState,
625 session: &Session,
626 action: &str,
627 resource: &R,
628) -> crate::AutumnResult<()>
629where
630 R: Send + Sync + 'static,
631{
632 let policy = state.policy_registry().policy::<R>().ok_or_else(|| {
633 crate::AutumnError::from(std::io::Error::other(format!(
634 "no policy registered for resource type {}",
635 std::any::type_name::<R>()
636 )))
637 .with_status(StatusCode::INTERNAL_SERVER_ERROR)
638 })?;
639
640 let ctx = PolicyContext::from_request(state, session).await;
641
642 if policy.can(action, &ctx, resource).await {
643 Ok(())
644 } else {
645 Err(state.forbidden_response().into_error())
646 }
647}
648
649#[doc(hidden)]
653pub async fn __check_policy<R>(
654 state: &crate::AppState,
655 session: &Session,
656 action: &str,
657 resource: &R,
658) -> crate::AutumnResult<()>
659where
660 R: Send + Sync + 'static,
661{
662 authorize(state, session, action, resource).await
663}
664
665#[doc(hidden)]
675pub async fn __check_policy_create<R>(
676 state: &crate::AppState,
677 session: &Session,
678) -> crate::AutumnResult<()>
679where
680 R: Send + Sync + 'static,
681{
682 authorize_create::<R>(state, session).await
683}
684
685#[doc(hidden)]
694pub async fn __check_policy_create_payload<R>(
695 state: &crate::AppState,
696 session: &Session,
697 payload: &serde_json::Value,
698) -> crate::AutumnResult<()>
699where
700 R: Send + Sync + 'static,
701{
702 authorize_create_payload::<R>(state, session, payload).await
703}
704
705pub async fn authorize_create<R>(
716 state: &crate::AppState,
717 session: &Session,
718) -> crate::AutumnResult<()>
719where
720 R: Send + Sync + 'static,
721{
722 let policy = state.policy_registry().policy::<R>().ok_or_else(|| {
723 crate::AutumnError::from(std::io::Error::other(format!(
724 "no policy registered for resource type {}",
725 std::any::type_name::<R>()
726 )))
727 .with_status(StatusCode::INTERNAL_SERVER_ERROR)
728 })?;
729
730 let ctx = PolicyContext::from_request(state, session).await;
731
732 if policy.can_create(&ctx).await {
733 Ok(())
734 } else {
735 Err(state.forbidden_response().into_error())
736 }
737}
738
739pub async fn authorize_create_payload<R>(
752 state: &crate::AppState,
753 session: &Session,
754 payload: &serde_json::Value,
755) -> crate::AutumnResult<()>
756where
757 R: Send + Sync + 'static,
758{
759 let policy = state.policy_registry().policy::<R>().ok_or_else(|| {
760 crate::AutumnError::from(std::io::Error::other(format!(
761 "no policy registered for resource type {}",
762 std::any::type_name::<R>()
763 )))
764 .with_status(StatusCode::INTERNAL_SERVER_ERROR)
765 })?;
766
767 let ctx = PolicyContext::from_request(state, session).await;
768
769 if policy.can_create_payload(&ctx, payload).await {
770 Ok(())
771 } else {
772 Err(state.forbidden_response().into_error())
773 }
774}
775
#[cfg(test)]
mod tests {
    //! Unit tests for the policy traits, the registry, `ForbiddenResponse`
    //! parsing, and the `authorize*` entry points.

    use super::*;
    use std::collections::HashMap;

    #[derive(Debug, Clone, PartialEq)]
    struct Note {
        author_id: i64,
    }

    /// Fixture policy:
    /// - anyone may show a note;
    /// - its author or an admin may update it;
    /// - only an admin may delete it;
    /// - create keeps the default deny.
    #[derive(Default)]
    struct AdminOrOwnerPolicy;

    impl Policy<Note> for AdminOrOwnerPolicy {
        fn can_show<'a>(&'a self, _ctx: &'a PolicyContext, _note: &'a Note) -> BoxFuture<'a, bool> {
            Box::pin(async { true })
        }
        fn can_update<'a>(&'a self, ctx: &'a PolicyContext, note: &'a Note) -> BoxFuture<'a, bool> {
            Box::pin(
                async move { ctx.has_role("admin") || ctx.user_id_i64() == Some(note.author_id) },
            )
        }
        fn can_delete<'a>(
            &'a self,
            ctx: &'a PolicyContext,
            _note: &'a Note,
        ) -> BoxFuture<'a, bool> {
            Box::pin(async move { ctx.has_role("admin") })
        }
    }

    /// Builds a context directly from an identity, bypassing session lookups.
    fn ctx(user_id: Option<&str>, role: Option<&str>) -> PolicyContext {
        let session = Session::new_for_test(String::new(), HashMap::new());
        PolicyContext {
            session,
            user_id: user_id.map(str::to_owned),
            roles: role.into_iter().map(str::to_owned).collect(),
            #[cfg(feature = "db")]
            pool: None,
            policy_registry: PolicyRegistry::default(),
        }
    }

    #[tokio::test]
    async fn default_impls_deny() {
        struct EmptyPolicy;
        impl Policy<Note> for EmptyPolicy {}
        let policy = EmptyPolicy;
        let c = ctx(Some("1"), None);
        let n = Note { author_id: 1 };
        assert!(!policy.can_show(&c, &n).await);
        assert!(!policy.can_create(&c).await);
        assert!(!policy.can_update(&c, &n).await);
        assert!(!policy.can_delete(&c, &n).await);
        assert!(!policy.can("publish", &c, &n).await);
    }

    #[tokio::test]
    async fn owner_can_update() {
        let policy = AdminOrOwnerPolicy;
        let c = ctx(Some("42"), None);
        let n = Note { author_id: 42 };
        assert!(policy.can_update(&c, &n).await);
        assert!(!policy.can_delete(&c, &n).await);
    }

    #[tokio::test]
    async fn non_owner_cannot_update() {
        let policy = AdminOrOwnerPolicy;
        let c = ctx(Some("99"), None);
        let n = Note { author_id: 42 };
        assert!(!policy.can_update(&c, &n).await);
    }

    #[tokio::test]
    async fn admin_can_delete() {
        let policy = AdminOrOwnerPolicy;
        let c = ctx(Some("99"), Some("admin"));
        let n = Note { author_id: 42 };
        assert!(policy.can_delete(&c, &n).await);
    }

    #[tokio::test]
    async fn can_dispatches_named_actions() {
        let policy = AdminOrOwnerPolicy;
        let c = ctx(Some("42"), None);
        let n = Note { author_id: 42 };
        assert!(policy.can("show", &c, &n).await);
        assert!(policy.can("update", &c, &n).await);
        assert!(policy.can("edit", &c, &n).await);
        assert!(!policy.can("publish", &c, &n).await);
    }

    #[test]
    fn policy_registry_stores_and_resolves() {
        let registry = PolicyRegistry::default();
        registry.register_policy::<Note, _>(AdminOrOwnerPolicy);
        assert!(registry.has_policy::<Note>());
        assert!(registry.policy::<Note>().is_some());
        assert!(registry.scope::<Note>().is_none());
    }

    #[test]
    #[should_panic(expected = "already registered")]
    fn double_policy_registration_panics() {
        let registry = PolicyRegistry::default();
        registry.register_policy::<Note, _>(AdminOrOwnerPolicy);
        registry.register_policy::<Note, _>(AdminOrOwnerPolicy);
    }

    #[test]
    fn forbidden_response_default_is_404() {
        let resp = ForbiddenResponse::default();
        assert_eq!(resp, ForbiddenResponse::NotFound404);
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[test]
    fn forbidden_response_parses_strings() {
        assert_eq!(
            "403".parse::<ForbiddenResponse>().unwrap(),
            ForbiddenResponse::Forbidden403
        );
        assert_eq!(
            "404".parse::<ForbiddenResponse>().unwrap(),
            ForbiddenResponse::NotFound404
        );
        assert_eq!(
            "forbidden".parse::<ForbiddenResponse>().unwrap(),
            ForbiddenResponse::Forbidden403
        );
        assert!("418".parse::<ForbiddenResponse>().is_err());
    }

    #[test]
    fn policy_context_helpers() {
        let c = ctx(Some("42"), Some("editor"));
        assert!(c.is_authenticated());
        assert_eq!(c.user_id_i64(), Some(42));
        assert!(c.has_role("editor"));
        assert!(!c.has_role("admin"));
        assert!(c.has_any_role(["admin", "editor"]));
        assert!(!c.has_any_role(["viewer", "guest"]));
    }

    #[test]
    fn anonymous_context_is_not_authenticated() {
        let c = ctx(None, None);
        assert!(!c.is_authenticated());
        assert!(c.user_id_i64().is_none());
        assert!(!c.has_role("admin"));
        assert!(!c.has_any_role(["admin", "editor"]));
    }

    #[test]
    fn user_id_i64_handles_non_numeric_session_value() {
        let c = ctx(Some("not-a-number"), None);
        assert!(c.user_id_i64().is_none());
    }

    #[test]
    fn forbidden_response_status_and_message_round_trip() {
        assert_eq!(
            ForbiddenResponse::Forbidden403.status(),
            StatusCode::FORBIDDEN
        );
        assert_eq!(
            ForbiddenResponse::NotFound404.status(),
            StatusCode::NOT_FOUND
        );
        assert_eq!(ForbiddenResponse::Forbidden403.message(), "forbidden");
        assert_eq!(ForbiddenResponse::NotFound404.message(), "not found");
    }

    #[test]
    fn forbidden_response_into_error_carries_status_and_message() {
        let err = ForbiddenResponse::NotFound404.into_error();
        assert_eq!(err.status(), StatusCode::NOT_FOUND);
        assert_eq!(err.to_string(), "not found");

        let err = ForbiddenResponse::Forbidden403.into_error();
        assert_eq!(err.status(), StatusCode::FORBIDDEN);
        assert_eq!(err.to_string(), "forbidden");
    }

    #[test]
    fn forbidden_response_parses_empty_string_as_default_404() {
        assert_eq!(
            "".parse::<ForbiddenResponse>().unwrap(),
            ForbiddenResponse::NotFound404
        );
        assert_eq!(
            "not_found".parse::<ForbiddenResponse>().unwrap(),
            ForbiddenResponse::NotFound404
        );
        assert_eq!(
            "NotFound".parse::<ForbiddenResponse>().unwrap(),
            ForbiddenResponse::NotFound404
        );
        assert_eq!(
            "Forbidden".parse::<ForbiddenResponse>().unwrap(),
            ForbiddenResponse::Forbidden403
        );
    }

    #[test]
    fn forbidden_response_parse_error_carries_input_value() {
        let err = "418".parse::<ForbiddenResponse>().unwrap_err();
        assert!(err.contains("418"));
        assert!(err.contains("403"));
        assert!(err.contains("404"));
    }

    #[test]
    fn forbidden_response_deserializes_from_toml() {
        #[derive(Debug, serde::Deserialize)]
        struct Holder {
            value: ForbiddenResponse,
        }
        let h: Holder = toml::from_str(r#"value = "403""#).unwrap();
        assert_eq!(h.value, ForbiddenResponse::Forbidden403);
        let h: Holder = toml::from_str(r#"value = "404""#).unwrap();
        assert_eq!(h.value, ForbiddenResponse::NotFound404);
        let err = toml::from_str::<Holder>(r#"value = "418""#).unwrap_err();
        assert!(err.to_string().contains("418"));
    }

    #[test]
    fn registry_scope_double_registration_panics_with_clear_message() {
        let registry = PolicyRegistry::default();
        registry.register_scope::<Note, _>(EmptyScope);
        let panicked = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            registry.register_scope::<Note, _>(EmptyScope);
        }))
        .unwrap_err();
        let msg = panicked
            .downcast_ref::<String>()
            .map(String::as_str)
            .or_else(|| panicked.downcast_ref::<&'static str>().copied())
            .unwrap_or("");
        assert!(
            msg.contains("already registered"),
            "expected double-registration panic, got {msg:?}"
        );
    }

    struct OtherResource;
    struct OtherPolicy;
    impl Policy<OtherResource> for OtherPolicy {}
    struct ThirdResource;
    struct EmptyScope;
    impl Scope<Note> for EmptyScope {}

    #[test]
    fn registry_resolves_distinct_resource_types_independently() {
        let registry = PolicyRegistry::default();
        registry.register_policy::<Note, _>(AdminOrOwnerPolicy);
        registry.register_policy::<OtherResource, _>(OtherPolicy);

        assert!(registry.has_policy::<Note>());
        assert!(registry.has_policy::<OtherResource>());
        assert!(!registry.has_policy::<ThirdResource>());
        assert!(registry.scope::<Note>().is_none());
    }

    #[test]
    fn registry_debug_shows_counts() {
        let registry = PolicyRegistry::default();
        registry.register_policy::<Note, _>(AdminOrOwnerPolicy);
        registry.register_scope::<Note, _>(EmptyScope);
        let dbg = format!("{registry:?}");
        assert!(dbg.contains("PolicyRegistry"));
        assert!(dbg.contains("policies"));
        assert!(dbg.contains("scopes"));
    }

    /// Detached app state that uses `forbidden` for denials and `"user_id"`
    /// as the auth session key (the key `session_with` writes). Policies are
    /// registered on `state.policy_registry()`.
    fn detached_state(forbidden: ForbiddenResponse) -> crate::AppState {
        crate::AppState::detached()
            .with_forbidden_response(forbidden)
            .with_auth_session_key("user_id")
    }

    /// Test session holding the optional `user_id` and `role` entries.
    fn session_with(user_id: Option<&str>, role: Option<&str>) -> Session {
        let mut data = HashMap::new();
        if let Some(u) = user_id {
            data.insert("user_id".to_owned(), u.to_owned());
        }
        if let Some(r) = role {
            data.insert("role".to_owned(), r.to_owned());
        }
        Session::new_for_test(String::new(), data)
    }

    #[tokio::test]
    async fn authorize_returns_500_when_no_policy_registered() {
        let state = detached_state(ForbiddenResponse::default());
        let session = session_with(Some("42"), None);
        let n = Note { author_id: 42 };
        let err = authorize::<Note>(&state, &session, "update", &n)
            .await
            .unwrap_err();
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn authorize_returns_configured_deny_when_policy_denies() {
        let state = detached_state(ForbiddenResponse::Forbidden403);
        state
            .policy_registry()
            .register_policy::<Note, _>(AdminOrOwnerPolicy);

        // User 99 is neither the author (42) nor an admin.
        let session = session_with(Some("99"), None);
        let n = Note { author_id: 42 };
        let err = authorize::<Note>(&state, &session, "update", &n)
            .await
            .unwrap_err();
        assert_eq!(err.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn authorize_returns_ok_when_policy_allows() {
        let state = crate::AppState::detached();
        state
            .policy_registry()
            .register_policy::<Note, _>(AdminOrOwnerPolicy);
        let session = session_with(Some("42"), None);
        let n = Note { author_id: 42 };
        authorize::<Note>(&state, &session, "update", &n)
            .await
            .expect("owner is allowed to update");
    }

    #[tokio::test]
    async fn authorize_create_returns_500_when_no_policy_registered() {
        let state = crate::AppState::detached();
        let session = session_with(Some("42"), None);
        let err = authorize_create::<Note>(&state, &session)
            .await
            .unwrap_err();
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn authorize_create_dispatches_can_create() {
        struct AuthOnlyCreatePolicy;
        impl Policy<Note> for AuthOnlyCreatePolicy {
            fn can_create<'a>(&'a self, ctx: &'a PolicyContext) -> BoxFuture<'a, bool> {
                Box::pin(async move { ctx.is_authenticated() })
            }
        }

        let state =
            crate::AppState::detached().with_forbidden_response(ForbiddenResponse::Forbidden403);
        state
            .policy_registry()
            .register_policy::<Note, _>(AuthOnlyCreatePolicy);

        let anon = session_with(None, None);
        let err = authorize_create::<Note>(&state, &anon).await.unwrap_err();
        assert_eq!(err.status(), StatusCode::FORBIDDEN);

        let user = session_with(Some("1"), None);
        authorize_create::<Note>(&state, &user)
            .await
            .expect("authenticated user passes can_create");
    }

    #[tokio::test]
    async fn authorize_create_payload_dispatches_can_create_payload() {
        struct OwnerPayloadPolicy;
        impl Policy<Note> for OwnerPayloadPolicy {
            fn can_create_payload<'a>(
                &'a self,
                ctx: &'a PolicyContext,
                payload: &'a serde_json::Value,
            ) -> BoxFuture<'a, bool> {
                Box::pin(async move {
                    payload.get("author_id").and_then(serde_json::Value::as_i64)
                        == ctx.user_id_i64()
                })
            }
        }

        let state =
            crate::AppState::detached().with_forbidden_response(ForbiddenResponse::Forbidden403);
        state
            .policy_registry()
            .register_policy::<Note, _>(OwnerPayloadPolicy);

        let user = session_with(Some("1"), None);
        let own_payload = serde_json::json!({"author_id": 1});
        authorize_create_payload::<Note>(&state, &user, &own_payload)
            .await
            .expect("owner payload passes can_create_payload");

        let other_payload = serde_json::json!({"author_id": 2});
        let err = authorize_create_payload::<Note>(&state, &user, &other_payload)
            .await
            .unwrap_err();
        assert_eq!(err.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn check_policy_create_alias_preserves_two_arg_shape() {
        struct AuthOnlyCreatePolicy;
        impl Policy<Note> for AuthOnlyCreatePolicy {
            fn can_create<'a>(&'a self, ctx: &'a PolicyContext) -> BoxFuture<'a, bool> {
                Box::pin(async move { ctx.is_authenticated() })
            }
        }

        let state =
            crate::AppState::detached().with_forbidden_response(ForbiddenResponse::Forbidden403);
        state
            .policy_registry()
            .register_policy::<Note, _>(AuthOnlyCreatePolicy);

        let anon = session_with(None, None);
        let err = __check_policy_create::<Note>(&state, &anon)
            .await
            .unwrap_err();
        assert_eq!(err.status(), StatusCode::FORBIDDEN);

        let user = session_with(Some("1"), None);
        __check_policy_create::<Note>(&state, &user)
            .await
            .expect("old generated create policy alias remains compatible");
    }

    #[tokio::test]
    async fn check_policy_create_payload_alias_dispatches_payload() {
        struct OwnerPayloadPolicy;
        impl Policy<Note> for OwnerPayloadPolicy {
            fn can_create_payload<'a>(
                &'a self,
                ctx: &'a PolicyContext,
                payload: &'a serde_json::Value,
            ) -> BoxFuture<'a, bool> {
                Box::pin(async move {
                    payload.get("author_id").and_then(serde_json::Value::as_i64)
                        == ctx.user_id_i64()
                })
            }
        }

        let state =
            crate::AppState::detached().with_forbidden_response(ForbiddenResponse::Forbidden403);
        state
            .policy_registry()
            .register_policy::<Note, _>(OwnerPayloadPolicy);

        let user = session_with(Some("1"), None);
        let payload = serde_json::json!({"author_id": 1});
        __check_policy_create_payload::<Note>(&state, &user, &payload)
            .await
            .expect("new generated create policy alias passes payload");
    }

    #[tokio::test]
    async fn check_policy_alias_round_trips() {
        let state = crate::AppState::detached();
        state
            .policy_registry()
            .register_policy::<Note, _>(AdminOrOwnerPolicy);
        let session = session_with(Some("42"), None);
        let n = Note { author_id: 42 };
        __check_policy::<Note>(&state, &session, "update", &n)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn from_request_clones_pool_and_registry_from_state() {
        let state = crate::AppState::detached();
        state
            .policy_registry()
            .register_policy::<Note, _>(AdminOrOwnerPolicy);
        let session = session_with(Some("7"), Some("admin"));
        let ctx = PolicyContext::from_request(&state, &session).await;
        assert_eq!(ctx.user_id.as_deref(), Some("7"));
        assert!(ctx.has_role("admin"));
        assert!(ctx.policy_registry.has_policy::<Note>());
        // The pool is copied exactly when the state has one configured.
        #[cfg(feature = "db")]
        assert_eq!(ctx.pool.is_some(), state.pool().is_some());
    }

    #[tokio::test]
    async fn scoped_blanket_trait_constructible_without_registered_scope() {
        let state = crate::AppState::detached();
        let session = session_with(Some("1"), None);
        let ctx = PolicyContext::from_request(&state, &session).await;
        let _query = Note::scope(&ctx);
        assert!(ctx.policy_registry.scope::<Note>().is_none());
    }
}