use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use http::StatusCode;
use crate::session::Session;
/// Heap-allocated, type-erased, `Send` future borrowing for `'a`.
///
/// Every async hook on [`Policy`] and [`Scope`] returns one of these so the
/// traits stay object-safe and can be stored as `Arc<dyn Policy<R>>` /
/// `Arc<dyn Scope<R>>` in the [`PolicyRegistry`].
pub type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
/// Per-request authorization inputs handed to [`Policy`] and [`Scope`] hooks.
///
/// Usually built with [`PolicyContext::from_request`], which also copies the
/// application's policy registry (and, with the `db` feature, its pool).
#[derive(Clone)]
pub struct PolicyContext {
/// Session the identity fields were read from.
pub session: Session,
/// Raw user id read from the session under the configured auth key;
/// `None` when the session holds no value there.
pub user_id: Option<String>,
/// Roles held by the user. `from_session` fills this from the single
/// `"role"` session key, so it has at most one entry when built that way.
pub roles: Vec<String>,
/// Database pool (only with the `db` feature); `None` unless supplied by the
/// app state in `from_request` or attached via [`PolicyContext::with_pool`].
#[cfg(feature = "db")]
pub pool:
Option<diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>>,
/// Registry used to resolve scopes; empty after `from_session`, the
/// application's shared registry after `from_request`.
pub policy_registry: PolicyRegistry,
}
impl PolicyContext {
    /// Builds a context from session data alone.
    ///
    /// The user id is read from `auth_session_key` and at most one role from
    /// the `"role"` key. The pool (with the `db` feature) starts as `None` and
    /// the registry is a fresh, empty [`PolicyRegistry`].
    pub async fn from_session(session: &Session, auth_session_key: &str) -> Self {
        let user_id = session.get(auth_session_key).await;
        let roles: Vec<String> = session.get("role").await.into_iter().collect();
        Self {
            session: session.clone(),
            user_id,
            roles,
            #[cfg(feature = "db")]
            pool: None,
            policy_registry: PolicyRegistry::default(),
        }
    }

    /// Builds a context for a request.
    ///
    /// Combines the session identity with the application's shared
    /// [`PolicyRegistry`]. With the `db` feature, it also attaches the
    /// application's connection pool when one is configured.
    pub async fn from_request(state: &crate::AppState, session: &Session) -> Self {
        let key = state.auth_session_key();
        let mut context = Self::from_session(session, key).await;
        context.policy_registry = state.policy_registry().clone();
        #[cfg(feature = "db")]
        {
            if let Some(pool) = state.pool() {
                context.pool = Some(pool.clone());
            }
        }
        context
    }

    /// Returns `true` when the session carried a user id.
    #[must_use]
    pub const fn is_authenticated(&self) -> bool {
        self.user_id.is_some()
    }

    /// Parses the stored user id as an `i64`.
    ///
    /// Returns `None` when there is no id or it is not a valid integer.
    #[must_use]
    pub fn user_id_i64(&self) -> Option<i64> {
        let raw = self.user_id.as_deref()?;
        raw.parse::<i64>().ok()
    }

    /// Exact, case-sensitive membership test against the held roles.
    #[must_use]
    pub fn has_role(&self, role: &str) -> bool {
        self.roles.iter().any(|held| held.as_str() == role)
    }

    /// Returns `true` if at least one of `candidates` is held.
    ///
    /// Stops at the first match; an empty candidate list yields `false`.
    #[must_use]
    pub fn has_any_role<I, S>(&self, candidates: I) -> bool
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        for candidate in candidates {
            if self.has_role(candidate.as_ref()) {
                return true;
            }
        }
        false
    }

    /// Attaches a database pool, replacing any pool already present.
    #[cfg(feature = "db")]
    #[must_use]
    pub fn with_pool(
        self,
        pool: diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>,
    ) -> Self {
        Self {
            pool: Some(pool),
            ..self
        }
    }
}
/// Authorization rules for resource type `R`.
///
/// Every hook defaults to denying, so an implementor only overrides the
/// actions it wants to allow.
pub trait Policy<R: Send + Sync + 'static>: Send + Sync + 'static {
    /// May `ctx` view `resource`? Default: no.
    fn can_show<'a>(&'a self, _ctx: &'a PolicyContext, _resource: &'a R) -> BoxFuture<'a, bool> {
        Box::pin(std::future::ready(false))
    }

    /// May `ctx` create a new `R`? Default: no.
    fn can_create<'a>(&'a self, _ctx: &'a PolicyContext) -> BoxFuture<'a, bool> {
        Box::pin(std::future::ready(false))
    }

    /// Create check that can also inspect the submitted payload.
    ///
    /// The default ignores the payload and defers to [`Policy::can_create`].
    fn can_create_payload<'a>(
        &'a self,
        ctx: &'a PolicyContext,
        _payload: &'a serde_json::Value,
    ) -> BoxFuture<'a, bool> {
        self.can_create(ctx)
    }

    /// May `ctx` modify `resource`? Default: no.
    fn can_update<'a>(&'a self, _ctx: &'a PolicyContext, _resource: &'a R) -> BoxFuture<'a, bool> {
        Box::pin(std::future::ready(false))
    }

    /// May `ctx` remove `resource`? Default: no.
    fn can_delete<'a>(&'a self, _ctx: &'a PolicyContext, _resource: &'a R) -> BoxFuture<'a, bool> {
        Box::pin(std::future::ready(false))
    }

    /// Dispatches a named action to the matching hook.
    ///
    /// The aliases are:
    /// - `"show"`/`"read"` → [`Policy::can_show`]
    /// - `"create"` → [`Policy::can_create`]
    /// - `"update"`/`"edit"` → [`Policy::can_update`]
    /// - `"delete"`/`"destroy"` → [`Policy::can_delete`]
    ///
    /// Any other action name (matching is case-sensitive) is denied.
    fn can<'a>(
        &'a self,
        action: &'a str,
        ctx: &'a PolicyContext,
        resource: &'a R,
    ) -> BoxFuture<'a, bool> {
        Box::pin(async move {
            // The hook is selected and awaited inside the async block, so
            // nothing runs until the returned future is polled.
            let check: BoxFuture<'a, bool> = match action {
                "show" | "read" => self.can_show(ctx, resource),
                "create" => self.can_create(ctx),
                "update" | "edit" => self.can_update(ctx, resource),
                "delete" | "destroy" => self.can_delete(ctx, resource),
                _ => return false,
            };
            check.await
        })
    }
}
/// Row-level visibility for resource type `R` (database-backed variant).
#[cfg(feature = "db")]
pub trait Scope<R: Send + Sync + 'static>: Send + Sync + 'static {
    /// Loads the `R` values visible to `ctx` using `conn`.
    ///
    /// The default yields an empty list, so a scope that does not override
    /// `list` exposes nothing.
    fn list<'a>(
        &'a self,
        _ctx: &'a PolicyContext,
        _conn: &'a mut diesel_async::AsyncPgConnection,
    ) -> BoxFuture<'a, crate::AutumnResult<Vec<R>>> {
        Box::pin(async move { Ok(Vec::default()) })
    }
}
/// Row-level visibility for resource type `R` (variant without the `db` feature).
#[cfg(not(feature = "db"))]
pub trait Scope<R: Send + Sync + 'static>: Send + Sync + 'static {
    /// Loads the `R` values visible to `ctx`.
    ///
    /// The default yields an empty list, so a scope that does not override
    /// `list` exposes nothing.
    fn list<'a>(&'a self, _ctx: &'a PolicyContext) -> BoxFuture<'a, crate::AutumnResult<Vec<R>>> {
        Box::pin(async move { Ok(Vec::default()) })
    }
}
/// A not-yet-executed scoped listing of `R`, produced by [`Scoped::scope`].
///
/// Awaiting `load` looks up the [`Scope`] registered for `R` in the
/// context's [`PolicyRegistry`] and runs it.
pub struct ScopeQuery<'a, R: Send + Sync + 'static> {
/// Context whose identity and registry the scope is evaluated against.
ctx: &'a PolicyContext,
// `fn() -> R` marks the query as producing `R` without storing one; function
// pointer types are always `Send + Sync`, so the marker adds no extra
// auto-trait requirements of its own.
_marker: std::marker::PhantomData<fn() -> R>,
}
#[cfg(feature = "db")]
impl<R: Send + Sync + 'static> ScopeQuery<'_, R> {
    /// Runs the [`Scope`] registered for `R` using `conn`.
    ///
    /// # Errors
    ///
    /// Returns an error with status 500 when no scope is registered for `R`.
    /// Otherwise it returns whatever the scope's `list` produces.
    pub async fn load(
        self,
        conn: &mut diesel_async::AsyncPgConnection,
    ) -> crate::AutumnResult<Vec<R>> {
        let Some(scope) = self.ctx.policy_registry.scope::<R>() else {
            let message = format!(
                "no scope registered for resource type {}",
                std::any::type_name::<R>()
            );
            return Err(crate::AutumnError::from(std::io::Error::other(message))
                .with_status(StatusCode::INTERNAL_SERVER_ERROR));
        };
        scope.list(self.ctx, conn).await
    }
}
#[cfg(not(feature = "db"))]
impl<R: Send + Sync + 'static> ScopeQuery<'_, R> {
pub async fn load(self) -> crate::AutumnResult<Vec<R>> {
let scope = self.ctx.policy_registry.scope::<R>().ok_or_else(|| {
crate::AutumnError::from(std::io::Error::other(format!(
"no scope registered for resource type {}",
std::any::type_name::<R>()
)))
.with_status(StatusCode::INTERNAL_SERVER_ERROR)
})?;
scope.list(self.ctx).await
}
}
/// Entry point for scoped listings: `Resource::scope(&ctx).load(...)`.
pub trait Scoped: Send + Sync + Sized + 'static {
    /// Starts a scoped query for `Self` against `ctx`.
    ///
    /// Nothing runs until `load` is awaited on the returned [`ScopeQuery`].
    #[must_use]
    fn scope(ctx: &PolicyContext) -> ScopeQuery<'_, Self> {
        ScopeQuery {
            _marker: std::marker::PhantomData::<fn() -> Self>,
            ctx,
        }
    }
}

// Every eligible type gets `Scoped` for free, so `T::scope(&ctx)` always
// compiles. A missing scope registration only surfaces as an error from `load`.
impl<T> Scoped for T where T: Send + Sync + 'static {}
/// Shared, type-keyed store of one [`Policy`] and one [`Scope`] per resource type.
///
/// Cloning is cheap and every clone refers to the same underlying maps
/// (they sit behind an `Arc`). Registration takes `&self` because the maps
/// are guarded by an internal `RwLock`.
#[derive(Clone, Default)]
pub struct PolicyRegistry {
inner: Arc<RwLock<RegistryInner>>,
}
/// Lock-protected contents of a [`PolicyRegistry`].
///
/// Both maps are keyed by `TypeId::of::<R>()`. Each value is an
/// `Arc<dyn Policy<R>>` / `Arc<dyn Scope<R>>` wrapped in another `Arc` and
/// erased to `dyn Any`, so one map can hold entries for different `R`.
/// `PolicyRegistry::policy` / `scope` downcast them back.
#[derive(Default)]
struct RegistryInner {
policies: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
scopes: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}
impl PolicyRegistry {
pub fn register_policy<R, P>(&self, policy: P)
where
R: Send + Sync + 'static,
P: Policy<R>,
{
let mut inner = self.inner.write().expect("policy registry lock poisoned");
let key = TypeId::of::<R>();
assert!(
!inner.policies.contains_key(&key),
"Policy for {} already registered. Multiple policies per resource are not supported.",
std::any::type_name::<R>()
);
let boxed: Arc<dyn Policy<R>> = Arc::new(policy);
inner.policies.insert(key, Arc::new(boxed));
}
pub fn register_scope<R, S>(&self, scope: S)
where
R: Send + Sync + 'static,
S: Scope<R>,
{
let mut inner = self.inner.write().expect("policy registry lock poisoned");
let key = TypeId::of::<R>();
assert!(
!inner.scopes.contains_key(&key),
"Scope for {} already registered. Multiple scopes per resource are not supported.",
std::any::type_name::<R>()
);
let boxed: Arc<dyn Scope<R>> = Arc::new(scope);
inner.scopes.insert(key, Arc::new(boxed));
}
#[must_use]
pub fn policy<R: Send + Sync + 'static>(&self) -> Option<Arc<dyn Policy<R>>> {
let inner = self.inner.read().expect("policy registry lock poisoned");
inner
.policies
.get(&TypeId::of::<R>())
.and_then(|a| a.downcast_ref::<Arc<dyn Policy<R>>>().cloned())
}
#[must_use]
pub fn scope<R: Send + Sync + 'static>(&self) -> Option<Arc<dyn Scope<R>>> {
let inner = self.inner.read().expect("policy registry lock poisoned");
inner
.scopes
.get(&TypeId::of::<R>())
.and_then(|a| a.downcast_ref::<Arc<dyn Scope<R>>>().cloned())
}
#[must_use]
pub fn has_policy<R: Send + Sync + 'static>(&self) -> bool {
self.inner
.read()
.expect("policy registry lock poisoned")
.policies
.contains_key(&TypeId::of::<R>())
}
}
impl std::fmt::Debug for PolicyRegistry {
    /// Prints how many policies and scopes are registered.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Debug` must never panic, since it can run from logging or panic
        // messages. Only the map lengths are read here, so recovering the
        // guard from a poisoned lock is safe.
        let inner = self
            .inner
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        f.debug_struct("PolicyRegistry")
            .field("policies", &inner.policies.len())
            .field("scopes", &inner.scopes.len())
            .finish()
    }
}
/// Status returned to the client when a policy denies an action.
///
/// Parsed from configuration via `FromStr` / `Deserialize` (e.g. `"403"` or
/// `"404"`). Defaults to `NotFound404` — presumably so a denial does not
/// reveal that the resource exists; confirm against the framework docs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ForbiddenResponse {
/// Respond with `403 Forbidden`.
Forbidden403,
/// Respond with `404 Not Found` (the default).
#[default]
NotFound404,
}
impl ForbiddenResponse {
    /// HTTP status sent to the client for this deny mode.
    #[must_use]
    pub const fn status(self) -> StatusCode {
        match self {
            Self::NotFound404 => StatusCode::NOT_FOUND,
            Self::Forbidden403 => StatusCode::FORBIDDEN,
        }
    }

    /// Short message paired with [`ForbiddenResponse::status`].
    #[must_use]
    pub const fn message(self) -> &'static str {
        match self {
            Self::NotFound404 => "not found",
            Self::Forbidden403 => "forbidden",
        }
    }

    /// Builds the error returned to callers when access is denied.
    ///
    /// The error carries this variant's message and status.
    #[must_use]
    pub fn into_error(self) -> crate::AutumnError {
        let status = self.status();
        let source = std::io::Error::other(self.message());
        crate::AutumnError::from(source).with_status(status)
    }
}
impl std::str::FromStr for ForbiddenResponse {
    type Err = String;

    /// Parses a configuration value into a deny mode.
    ///
    /// Surrounding whitespace is ignored and matching is case-sensitive.
    /// - `"403"`, `"forbidden"`, `"Forbidden"` → `Forbidden403`
    /// - `"404"`, `"not_found"`, `"NotFound"` and the empty string → `NotFound404`
    ///
    /// Anything else is rejected with a message quoting the trimmed input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const FORBIDDEN_ALIASES: [&str; 3] = ["403", "forbidden", "Forbidden"];
        const NOT_FOUND_ALIASES: [&str; 4] = ["404", "not_found", "NotFound", ""];

        let token = s.trim();
        if FORBIDDEN_ALIASES.contains(&token) {
            Ok(Self::Forbidden403)
        } else if NOT_FOUND_ALIASES.contains(&token) {
            Ok(Self::NotFound404)
        } else {
            Err(format!(
                "invalid forbidden_response: {token:?} (expected \"403\" or \"404\")"
            ))
        }
    }
}
impl<'de> serde::Deserialize<'de> for ForbiddenResponse {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = String::deserialize(deserializer)?;
raw.parse().map_err(serde::de::Error::custom)
}
}
pub async fn authorize<R>(
state: &crate::AppState,
session: &Session,
action: &str,
resource: &R,
) -> crate::AutumnResult<()>
where
R: Send + Sync + 'static,
{
let policy = state.policy_registry().policy::<R>().ok_or_else(|| {
crate::AutumnError::from(std::io::Error::other(format!(
"no policy registered for resource type {}",
std::any::type_name::<R>()
)))
.with_status(StatusCode::INTERNAL_SERVER_ERROR)
})?;
let ctx = PolicyContext::from_request(state, session).await;
if policy.can(action, &ctx, resource).await {
Ok(())
} else {
Err(state.forbidden_response().into_error())
}
}
/// Hidden alias of [`authorize`], kept stable for generated code.
#[doc(hidden)]
pub async fn __check_policy<R: Send + Sync + 'static>(
    state: &crate::AppState,
    session: &Session,
    action: &str,
    resource: &R,
) -> crate::AutumnResult<()> {
    authorize::<R>(state, session, action, resource).await
}
/// Hidden two-argument alias of [`authorize_create`], kept for generated code.
#[doc(hidden)]
pub async fn __check_policy_create<R: Send + Sync + 'static>(
    state: &crate::AppState,
    session: &Session,
) -> crate::AutumnResult<()> {
    authorize_create::<R>(state, session).await
}
/// Hidden alias of [`authorize_create_payload`], kept for generated code.
#[doc(hidden)]
pub async fn __check_policy_create_payload<R: Send + Sync + 'static>(
    state: &crate::AppState,
    session: &Session,
    payload: &serde_json::Value,
) -> crate::AutumnResult<()> {
    authorize_create_payload::<R>(state, session, payload).await
}
pub async fn authorize_create<R>(
state: &crate::AppState,
session: &Session,
) -> crate::AutumnResult<()>
where
R: Send + Sync + 'static,
{
let policy = state.policy_registry().policy::<R>().ok_or_else(|| {
crate::AutumnError::from(std::io::Error::other(format!(
"no policy registered for resource type {}",
std::any::type_name::<R>()
)))
.with_status(StatusCode::INTERNAL_SERVER_ERROR)
})?;
let ctx = PolicyContext::from_request(state, session).await;
if policy.can_create(&ctx).await {
Ok(())
} else {
Err(state.forbidden_response().into_error())
}
}
pub async fn authorize_create_payload<R>(
state: &crate::AppState,
session: &Session,
payload: &serde_json::Value,
) -> crate::AutumnResult<()>
where
R: Send + Sync + 'static,
{
let policy = state.policy_registry().policy::<R>().ok_or_else(|| {
crate::AutumnError::from(std::io::Error::other(format!(
"no policy registered for resource type {}",
std::any::type_name::<R>()
)))
.with_status(StatusCode::INTERNAL_SERVER_ERROR)
})?;
let ctx = PolicyContext::from_request(state, session).await;
if policy.can_create_payload(&ctx, payload).await {
Ok(())
} else {
Err(state.forbidden_response().into_error())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[derive(Debug, Clone, PartialEq)]
    struct Note {
        author_id: i64,
    }

    /// Test policy: anyone may show, the owner or an admin may update, only
    /// an admin may delete.
    #[derive(Default)]
    struct AdminOrOwnerPolicy;

    impl Policy<Note> for AdminOrOwnerPolicy {
        fn can_show<'a>(&'a self, _ctx: &'a PolicyContext, _note: &'a Note) -> BoxFuture<'a, bool> {
            Box::pin(async { true })
        }

        fn can_update<'a>(&'a self, ctx: &'a PolicyContext, note: &'a Note) -> BoxFuture<'a, bool> {
            Box::pin(
                async move { ctx.has_role("admin") || ctx.user_id_i64() == Some(note.author_id) },
            )
        }

        fn can_delete<'a>(
            &'a self,
            ctx: &'a PolicyContext,
            _note: &'a Note,
        ) -> BoxFuture<'a, bool> {
            Box::pin(async move { ctx.has_role("admin") })
        }
    }

    /// Builds a context directly, bypassing session lookups.
    fn ctx(user_id: Option<&str>, role: Option<&str>) -> PolicyContext {
        let session = Session::new_for_test(String::new(), HashMap::new());
        PolicyContext {
            session,
            user_id: user_id.map(str::to_owned),
            roles: role.into_iter().map(str::to_owned).collect(),
            #[cfg(feature = "db")]
            pool: None,
            policy_registry: PolicyRegistry::default(),
        }
    }

    #[tokio::test]
    async fn default_impls_deny() {
        struct EmptyPolicy;
        impl Policy<Note> for EmptyPolicy {}
        let policy = EmptyPolicy;
        let c = ctx(Some("1"), None);
        let n = Note { author_id: 1 };
        assert!(!policy.can_show(&c, &n).await);
        assert!(!policy.can_create(&c).await);
        assert!(!policy.can_update(&c, &n).await);
        assert!(!policy.can_delete(&c, &n).await);
        assert!(!policy.can("publish", &c, &n).await);
    }

    #[tokio::test]
    async fn owner_can_update() {
        let policy = AdminOrOwnerPolicy;
        let c = ctx(Some("42"), None);
        let n = Note { author_id: 42 };
        assert!(policy.can_update(&c, &n).await);
        assert!(!policy.can_delete(&c, &n).await);
    }

    #[tokio::test]
    async fn non_owner_cannot_update() {
        let policy = AdminOrOwnerPolicy;
        let c = ctx(Some("99"), None);
        let n = Note { author_id: 42 };
        assert!(!policy.can_update(&c, &n).await);
    }

    #[tokio::test]
    async fn admin_can_delete() {
        let policy = AdminOrOwnerPolicy;
        let c = ctx(Some("99"), Some("admin"));
        let n = Note { author_id: 42 };
        assert!(policy.can_delete(&c, &n).await);
    }

    #[tokio::test]
    async fn can_dispatches_named_actions() {
        let policy = AdminOrOwnerPolicy;
        let c = ctx(Some("42"), None);
        let n = Note { author_id: 42 };
        assert!(policy.can("show", &c, &n).await);
        assert!(policy.can("update", &c, &n).await);
        assert!(policy.can("edit", &c, &n).await);
        assert!(!policy.can("publish", &c, &n).await);
    }

    #[test]
    fn policy_registry_stores_and_resolves() {
        let registry = PolicyRegistry::default();
        registry.register_policy::<Note, _>(AdminOrOwnerPolicy);
        assert!(registry.has_policy::<Note>());
        assert!(registry.policy::<Note>().is_some());
        assert!(registry.scope::<Note>().is_none());
    }

    #[test]
    #[should_panic(expected = "already registered")]
    fn double_policy_registration_panics() {
        let registry = PolicyRegistry::default();
        registry.register_policy::<Note, _>(AdminOrOwnerPolicy);
        registry.register_policy::<Note, _>(AdminOrOwnerPolicy);
    }

    #[test]
    fn forbidden_response_default_is_404() {
        let resp = ForbiddenResponse::default();
        assert_eq!(resp, ForbiddenResponse::NotFound404);
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
    }

    #[test]
    fn forbidden_response_parses_strings() {
        assert_eq!(
            "403".parse::<ForbiddenResponse>().unwrap(),
            ForbiddenResponse::Forbidden403
        );
        assert_eq!(
            "404".parse::<ForbiddenResponse>().unwrap(),
            ForbiddenResponse::NotFound404
        );
        assert_eq!(
            "forbidden".parse::<ForbiddenResponse>().unwrap(),
            ForbiddenResponse::Forbidden403
        );
        assert!("418".parse::<ForbiddenResponse>().is_err());
    }

    #[test]
    fn policy_context_helpers() {
        let c = ctx(Some("42"), Some("editor"));
        assert!(c.is_authenticated());
        assert_eq!(c.user_id_i64(), Some(42));
        assert!(c.has_role("editor"));
        assert!(!c.has_role("admin"));
        assert!(c.has_any_role(["admin", "editor"]));
        assert!(!c.has_any_role(["viewer", "guest"]));
    }

    #[test]
    fn anonymous_context_is_not_authenticated() {
        let c = ctx(None, None);
        assert!(!c.is_authenticated());
        assert!(c.user_id_i64().is_none());
        assert!(!c.has_role("admin"));
        assert!(!c.has_any_role(["admin", "editor"]));
    }

    #[test]
    fn user_id_i64_handles_non_numeric_session_value() {
        let c = ctx(Some("not-a-number"), None);
        assert!(c.user_id_i64().is_none());
    }

    #[test]
    fn forbidden_response_status_and_message_round_trip() {
        assert_eq!(
            ForbiddenResponse::Forbidden403.status(),
            StatusCode::FORBIDDEN
        );
        assert_eq!(
            ForbiddenResponse::NotFound404.status(),
            StatusCode::NOT_FOUND
        );
        assert_eq!(ForbiddenResponse::Forbidden403.message(), "forbidden");
        assert_eq!(ForbiddenResponse::NotFound404.message(), "not found");
    }

    #[test]
    fn forbidden_response_into_error_carries_status_and_message() {
        let err = ForbiddenResponse::NotFound404.into_error();
        assert_eq!(err.status(), StatusCode::NOT_FOUND);
        assert_eq!(err.to_string(), "not found");
        let err = ForbiddenResponse::Forbidden403.into_error();
        assert_eq!(err.status(), StatusCode::FORBIDDEN);
        assert_eq!(err.to_string(), "forbidden");
    }

    #[test]
    fn forbidden_response_parses_empty_string_as_default_404() {
        assert_eq!(
            "".parse::<ForbiddenResponse>().unwrap(),
            ForbiddenResponse::NotFound404
        );
        assert_eq!(
            "not_found".parse::<ForbiddenResponse>().unwrap(),
            ForbiddenResponse::NotFound404
        );
        assert_eq!(
            "NotFound".parse::<ForbiddenResponse>().unwrap(),
            ForbiddenResponse::NotFound404
        );
        assert_eq!(
            "Forbidden".parse::<ForbiddenResponse>().unwrap(),
            ForbiddenResponse::Forbidden403
        );
    }

    #[test]
    fn forbidden_response_parse_error_carries_input_value() {
        let err = "418".parse::<ForbiddenResponse>().unwrap_err();
        assert!(err.contains("418"));
        assert!(err.contains("403"));
        assert!(err.contains("404"));
    }

    #[test]
    fn forbidden_response_deserializes_from_toml() {
        #[derive(Debug, serde::Deserialize)]
        struct Holder {
            value: ForbiddenResponse,
        }
        let h: Holder = toml::from_str(r#"value = "403""#).unwrap();
        assert_eq!(h.value, ForbiddenResponse::Forbidden403);
        let h: Holder = toml::from_str(r#"value = "404""#).unwrap();
        assert_eq!(h.value, ForbiddenResponse::NotFound404);
        let err = toml::from_str::<Holder>(r#"value = "418""#).unwrap_err();
        assert!(err.to_string().contains("418"));
    }

    #[test]
    fn registry_scope_double_registration_panics_with_clear_message() {
        let registry = PolicyRegistry::default();
        registry.register_scope::<Note, _>(EmptyScope);
        let panicked = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            registry.register_scope::<Note, _>(EmptyScope);
        }))
        .unwrap_err();
        let msg = panicked
            .downcast_ref::<String>()
            .map(String::as_str)
            .or_else(|| panicked.downcast_ref::<&'static str>().copied())
            .unwrap_or("");
        assert!(
            msg.contains("already registered"),
            "expected double-registration panic, got {msg:?}"
        );
    }

    struct OtherResource;
    struct OtherPolicy;
    impl Policy<OtherResource> for OtherPolicy {}
    struct ThirdResource;
    struct EmptyScope;
    impl Scope<Note> for EmptyScope {}

    #[test]
    fn registry_resolves_distinct_resource_types_independently() {
        let registry = PolicyRegistry::default();
        registry.register_policy::<Note, _>(AdminOrOwnerPolicy);
        registry.register_policy::<OtherResource, _>(OtherPolicy);
        assert!(registry.has_policy::<Note>());
        assert!(registry.has_policy::<OtherResource>());
        assert!(!registry.has_policy::<ThirdResource>());
        assert!(registry.scope::<Note>().is_none());
    }

    #[test]
    fn registry_debug_shows_counts() {
        let registry = PolicyRegistry::default();
        registry.register_policy::<Note, _>(AdminOrOwnerPolicy);
        registry.register_scope::<Note, _>(EmptyScope);
        let dbg = format!("{registry:?}");
        assert!(dbg.contains("PolicyRegistry"));
        assert!(dbg.contains("policies"));
        assert!(dbg.contains("scopes"));
    }

    /// Detached app state with the given deny response.
    ///
    /// Uses `"user_id"` as the auth session key, matching `session_with`.
    /// Policies must be registered on `state.policy_registry()` afterwards.
    fn detached_state_with(forbidden: ForbiddenResponse) -> crate::AppState {
        crate::AppState::detached()
            .with_forbidden_response(forbidden)
            .with_auth_session_key("user_id")
    }

    /// Test session holding the given user id (under `"user_id"`) and role.
    fn session_with(user_id: Option<&str>, role: Option<&str>) -> Session {
        let mut data = HashMap::new();
        if let Some(u) = user_id {
            data.insert("user_id".to_owned(), u.to_owned());
        }
        if let Some(r) = role {
            data.insert("role".to_owned(), r.to_owned());
        }
        Session::new_for_test(String::new(), data)
    }

    #[tokio::test]
    async fn authorize_returns_500_when_no_policy_registered() {
        let state = detached_state_with(ForbiddenResponse::default());
        let session = session_with(Some("42"), None);
        let n = Note { author_id: 42 };
        let err = authorize::<Note>(&state, &session, "update", &n)
            .await
            .unwrap_err();
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn authorize_returns_configured_deny_when_policy_denies() {
        let state = detached_state_with(ForbiddenResponse::Forbidden403);
        state
            .policy_registry()
            .register_policy::<Note, _>(AdminOrOwnerPolicy);
        let session = session_with(Some("99"), None);
        let n = Note { author_id: 42 };
        let err = authorize::<Note>(&state, &session, "update", &n)
            .await
            .unwrap_err();
        assert_eq!(err.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn authorize_returns_ok_when_policy_allows() {
        let state = crate::AppState::detached();
        state
            .policy_registry()
            .register_policy::<Note, _>(AdminOrOwnerPolicy);
        let session = session_with(Some("42"), None);
        let n = Note { author_id: 42 };
        authorize::<Note>(&state, &session, "update", &n)
            .await
            .expect("owner is allowed to update");
    }

    #[tokio::test]
    async fn authorize_create_returns_500_when_no_policy_registered() {
        let state = crate::AppState::detached();
        let session = session_with(Some("42"), None);
        let err = authorize_create::<Note>(&state, &session)
            .await
            .unwrap_err();
        assert_eq!(err.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn authorize_create_dispatches_can_create() {
        struct AuthOnlyCreatePolicy;
        impl Policy<Note> for AuthOnlyCreatePolicy {
            fn can_create<'a>(&'a self, ctx: &'a PolicyContext) -> BoxFuture<'a, bool> {
                Box::pin(async move { ctx.is_authenticated() })
            }
        }
        let state =
            crate::AppState::detached().with_forbidden_response(ForbiddenResponse::Forbidden403);
        state
            .policy_registry()
            .register_policy::<Note, _>(AuthOnlyCreatePolicy);
        let anon = session_with(None, None);
        let err = authorize_create::<Note>(&state, &anon).await.unwrap_err();
        assert_eq!(err.status(), StatusCode::FORBIDDEN);
        let user = session_with(Some("1"), None);
        authorize_create::<Note>(&state, &user)
            .await
            .expect("authenticated user passes can_create");
    }

    #[tokio::test]
    async fn authorize_create_payload_dispatches_can_create_payload() {
        struct OwnerPayloadPolicy;
        impl Policy<Note> for OwnerPayloadPolicy {
            fn can_create_payload<'a>(
                &'a self,
                ctx: &'a PolicyContext,
                payload: &'a serde_json::Value,
            ) -> BoxFuture<'a, bool> {
                Box::pin(async move {
                    payload.get("author_id").and_then(serde_json::Value::as_i64)
                        == ctx.user_id_i64()
                })
            }
        }
        let state =
            crate::AppState::detached().with_forbidden_response(ForbiddenResponse::Forbidden403);
        state
            .policy_registry()
            .register_policy::<Note, _>(OwnerPayloadPolicy);
        let user = session_with(Some("1"), None);
        let own_payload = serde_json::json!({"author_id": 1});
        authorize_create_payload::<Note>(&state, &user, &own_payload)
            .await
            .expect("owner payload passes can_create_payload");
        let other_payload = serde_json::json!({"author_id": 2});
        let err = authorize_create_payload::<Note>(&state, &user, &other_payload)
            .await
            .unwrap_err();
        assert_eq!(err.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn check_policy_create_alias_preserves_two_arg_shape() {
        struct AuthOnlyCreatePolicy;
        impl Policy<Note> for AuthOnlyCreatePolicy {
            fn can_create<'a>(&'a self, ctx: &'a PolicyContext) -> BoxFuture<'a, bool> {
                Box::pin(async move { ctx.is_authenticated() })
            }
        }
        let state =
            crate::AppState::detached().with_forbidden_response(ForbiddenResponse::Forbidden403);
        state
            .policy_registry()
            .register_policy::<Note, _>(AuthOnlyCreatePolicy);
        let anon = session_with(None, None);
        let err = __check_policy_create::<Note>(&state, &anon)
            .await
            .unwrap_err();
        assert_eq!(err.status(), StatusCode::FORBIDDEN);
        let user = session_with(Some("1"), None);
        __check_policy_create::<Note>(&state, &user)
            .await
            .expect("old generated create policy alias remains compatible");
    }

    #[tokio::test]
    async fn check_policy_create_payload_alias_dispatches_payload() {
        struct OwnerPayloadPolicy;
        impl Policy<Note> for OwnerPayloadPolicy {
            fn can_create_payload<'a>(
                &'a self,
                ctx: &'a PolicyContext,
                payload: &'a serde_json::Value,
            ) -> BoxFuture<'a, bool> {
                Box::pin(async move {
                    payload.get("author_id").and_then(serde_json::Value::as_i64)
                        == ctx.user_id_i64()
                })
            }
        }
        let state =
            crate::AppState::detached().with_forbidden_response(ForbiddenResponse::Forbidden403);
        state
            .policy_registry()
            .register_policy::<Note, _>(OwnerPayloadPolicy);
        let user = session_with(Some("1"), None);
        let payload = serde_json::json!({"author_id": 1});
        __check_policy_create_payload::<Note>(&state, &user, &payload)
            .await
            .expect("new generated create policy alias passes payload");
    }

    #[tokio::test]
    async fn check_policy_alias_round_trips() {
        let state = crate::AppState::detached();
        state
            .policy_registry()
            .register_policy::<Note, _>(AdminOrOwnerPolicy);
        let session = session_with(Some("42"), None);
        let n = Note { author_id: 42 };
        __check_policy::<Note>(&state, &session, "update", &n)
            .await
            .unwrap();
    }

    // Only the identity fields and the registry are asserted here; the pool
    // exists only with the `db` feature and a detached state may not have one.
    #[tokio::test]
    async fn from_request_clones_pool_and_registry_from_state() {
        let state = crate::AppState::detached();
        state
            .policy_registry()
            .register_policy::<Note, _>(AdminOrOwnerPolicy);
        let session = session_with(Some("7"), Some("admin"));
        let ctx = PolicyContext::from_request(&state, &session).await;
        assert_eq!(ctx.user_id.as_deref(), Some("7"));
        assert!(ctx.has_role("admin"));
        assert!(ctx.policy_registry.has_policy::<Note>());
    }

    #[tokio::test]
    async fn scoped_blanket_trait_constructible_without_registered_scope() {
        let state = crate::AppState::detached();
        let session = session_with(Some("1"), None);
        let ctx = PolicyContext::from_request(&state, &session).await;
        let _query = Note::scope(&ctx);
        assert!(ctx.policy_registry.scope::<Note>().is_none());
    }
}