use super::claims::OrgClaims;
use super::extractors::AuthenticatedUserId;
use crate::error::TidewayError;
use crate::organizations::storage::MembershipStore;
use axum::{extract::Request, middleware::Next, response::Response};
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
/// Boxed, `Send` future returned by the middleware functions built by
/// [`RequirePermission`]; it resolves to the downstream [`Response`] or to the
/// [`TidewayError`] that rejected the request.
pub type MiddlewareFuture =
    Pin<Box<dyn Future<Output = Result<Response, TidewayError>> + Send + 'static>>;
/// Per-request authorization context: the membership store plus the
/// organization (taken from the token's [`OrgClaims`]) and the authenticated
/// user whose membership is being checked.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the
/// struct itself, so merely naming the type imposes no requirements.
struct OrgContext<M> {
    /// Store used to look up the user's membership in `org_id`.
    store: M,
    /// Organization id copied from the token's org claims.
    org_id: String,
    /// Id of the authenticated user.
    user_id: String,
}
impl<M> OrgContext<M>
where
    M: MembershipStore + Clone + 'static,
{
    /// Rejection message shared by every "caller is not in this org" path.
    const NOT_A_MEMBER: &'static str = "Not a member of this organization";

    /// Collects the store, organization id and user id from `request`'s
    /// extensions.
    ///
    /// The store is looked up first, so a missing store is reported even for a
    /// request that also lacks authentication data.
    ///
    /// # Errors
    ///
    /// - `TidewayError::internal` if no `M` was inserted into the extensions.
    /// - `TidewayError::unauthorized` if [`OrgClaims`] or
    ///   [`AuthenticatedUserId`] is absent.
    fn from_request(request: &Request) -> Result<Self, TidewayError> {
        let extensions = request.extensions();

        let Some(store) = extensions.get::<M>().cloned() else {
            return Err(TidewayError::internal(
                "MembershipStore not found in request extensions",
            ));
        };
        let Some(claims) = extensions.get::<OrgClaims>() else {
            return Err(TidewayError::unauthorized("No organization context in token"));
        };
        let Some(authenticated) = extensions.get::<AuthenticatedUserId>() else {
            return Err(TidewayError::unauthorized("User not authenticated"));
        };

        Ok(Self {
            store,
            org_id: claims.org_id.clone(),
            user_id: authenticated.0.clone(),
        })
    }

    /// Confirms through `MembershipStore::is_member` that the user belongs to
    /// the organization.
    ///
    /// # Errors
    ///
    /// - `TidewayError::internal` (wrapping the store error text) if the lookup fails.
    /// - `TidewayError::forbidden` if the user is not a member.
    async fn check_membership(&self) -> Result<(), TidewayError> {
        match self.store.is_member(&self.org_id, &self.user_id).await {
            Ok(true) => Ok(()),
            Ok(false) => Err(TidewayError::forbidden(Self::NOT_A_MEMBER)),
            Err(e) => Err(TidewayError::internal(format!(
                "Failed to check membership: {e}"
            ))),
        }
    }

    /// Loads the user's membership record and derives its role with
    /// `MembershipStore::membership_role`.
    ///
    /// # Errors
    ///
    /// - `TidewayError::internal` (wrapping the store error text) if the lookup fails.
    /// - `TidewayError::forbidden` if the store has no membership for the user.
    async fn get_membership_and_role(&self) -> Result<(M::Membership, M::Role), TidewayError> {
        let lookup = self.store.get_membership(&self.org_id, &self.user_id).await;
        let membership = match lookup {
            Ok(Some(found)) => found,
            Ok(None) => return Err(TidewayError::forbidden(Self::NOT_A_MEMBER)),
            Err(e) => {
                return Err(TidewayError::internal(format!(
                    "Failed to get membership: {e}"
                )));
            }
        };
        let role = self.store.membership_role(&membership);
        Ok((membership, role))
    }
}
/// Middleware namespace that requires the caller to be a member of the
/// organization named in their token; see its `middleware` function.
///
/// `M` selects the membership store type that the middleware reads from the
/// request extensions. The bound is expressed on the `impl` block, not here, so
/// naming the type imposes no requirements of its own.
pub struct RequireOrgMembership<M> {
    _store: PhantomData<M>,
}
impl<M> RequireOrgMembership<M>
where
    M: MembershipStore + Clone + 'static,
{
    /// Middleware function that forwards the request to `next` only if the
    /// authenticated user is a member of the organization in their token.
    ///
    /// The `M` store must already be in the request extensions (for example,
    /// inserted by `OrgStoreLayer`).
    ///
    /// # Errors
    ///
    /// - `TidewayError::internal` if the store is missing or the membership
    ///   lookup fails.
    /// - `TidewayError::unauthorized` if the org claims or the authenticated
    ///   user id are missing.
    /// - `TidewayError::forbidden` if the user is not a member.
    pub async fn middleware(request: Request, next: Next) -> Result<Response, TidewayError> {
        let context = OrgContext::<M>::from_request(&request)?;
        context.check_membership().await?;
        let response = next.run(request).await;
        Ok(response)
    }
}
/// Holds a membership store and, through its `middleware` method, inserts a
/// clone of it into each request's extensions, where `RequireOrgMembership`
/// and `RequirePermission` look it up.
///
/// The struct carries no trait bounds; the `impl` block states what it needs.
#[derive(Clone)]
pub struct OrgStoreLayer<M> {
    store: M,
}
impl<M: MembershipStore + Clone + 'static> OrgStoreLayer<M> {
#[must_use]
pub fn new(store: M) -> Self {
Self { store }
}
pub async fn middleware(&self, mut request: Request, next: Next) -> Response {
request.extensions_mut().insert(self.store.clone());
next.run(request).await
}
}
/// Namespace for middleware builders that require a specific permission,
/// evaluated against the caller's role in the organization named in the token.
///
/// `M` selects the membership store type read from the request extensions. The
/// bound lives on the `impl` block, so naming the type imposes no requirements
/// of its own.
pub struct RequirePermission<M> {
    _store: PhantomData<M>,
}
impl<M> RequirePermission<M>
where
    M: MembershipStore + Clone + 'static,
{
    /// Builds a middleware function that forwards a request only when `check`
    /// returns `true` for the caller's role in the token's organization.
    ///
    /// For every request, the returned function:
    /// 1. reads the `M` store, [`OrgClaims`] and [`AuthenticatedUserId`] from
    ///    the request extensions;
    /// 2. loads the caller's membership and derives the role with
    ///    `MembershipStore::membership_role`;
    /// 3. evaluates `check(store, role)`.
    ///
    /// # Errors
    ///
    /// The returned future resolves to:
    /// - `TidewayError::internal` if the store is missing from the extensions
    ///   or the membership lookup fails;
    /// - `TidewayError::unauthorized` if the org claims or the authenticated
    ///   user id are missing;
    /// - `TidewayError::forbidden` if the caller is not a member or `check`
    ///   returns `false`.
    #[must_use = "the permission check only runs if the returned middleware is installed"]
    pub fn check<F>(
        check: F,
    ) -> impl Fn(Request, Next) -> MiddlewareFuture + Clone + Send + Sync + 'static
    where
        F: Fn(&M, &M::Role) -> bool + Clone + Send + Sync + 'static,
    {
        move |request: Request, next: Next| {
            // The `async move` block takes ownership of its predicate. A per-request
            // clone keeps the outer closure `Fn` rather than `FnOnce`.
            let check = check.clone();
            Box::pin(async move {
                let ctx = OrgContext::<M>::from_request(&request)?;
                let (_, role) = ctx.get_membership_and_role().await?;
                if !check(&ctx.store, &role) {
                    return Err(TidewayError::forbidden("Insufficient permissions"));
                }
                Ok(next.run(request).await)
            })
        }
    }

    /// Requires a role for which `MembershipStore::can_manage_members` holds.
    #[must_use = "the permission check only runs if the returned middleware is installed"]
    pub fn can_manage_members()
    -> impl Fn(Request, Next) -> MiddlewareFuture + Clone + Send + Sync + 'static {
        Self::check(|store, role| store.can_manage_members(role))
    }

    /// Requires a role for which `MembershipStore::can_manage_settings` holds.
    #[must_use = "the permission check only runs if the returned middleware is installed"]
    pub fn can_manage_settings()
    -> impl Fn(Request, Next) -> MiddlewareFuture + Clone + Send + Sync + 'static {
        Self::check(|store, role| store.can_manage_settings(role))
    }

    /// Requires a role for which `MembershipStore::can_delete_org` holds.
    #[must_use = "the permission check only runs if the returned middleware is installed"]
    pub fn can_delete_org()
    -> impl Fn(Request, Next) -> MiddlewareFuture + Clone + Send + Sync + 'static {
        Self::check(|store, role| store.can_delete_org(role))
    }

    /// Requires a role for which `MembershipStore::is_owner` holds.
    #[must_use = "the permission check only runs if the returned middleware is installed"]
    pub fn is_owner() -> impl Fn(Request, Next) -> MiddlewareFuture + Clone + Send + Sync + 'static
    {
        Self::check(|store, role| store.is_owner(role))
    }
}