use axum::http::{HeaderMap, StatusCode};
use uuid::Uuid;
use crate::{models::Organization, AppState};
/// Checks that `user_id` may act within the organization `org_id`.
///
/// Access is granted when the user is the organization's owner, or when
/// `OrgMember::find` returns a membership row for the (org, user) pair.
///
/// # Errors
/// - `NOT_FOUND` if no organization has id `org_id`.
/// - `FORBIDDEN` if the user is neither the owner nor a member.
/// - `INTERNAL_SERVER_ERROR` if either database lookup fails.
async fn verify_org_access(
    pool: &sqlx::PgPool,
    org_id: Uuid,
    user_id: Uuid,
) -> Result<(), StatusCode> {
    use crate::models::OrgMember;

    let lookup = Organization::find_by_id(pool, org_id)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let organization = match lookup {
        Some(found) => found,
        None => return Err(StatusCode::NOT_FOUND),
    };

    // Owners are always allowed; skip the membership query for them.
    if organization.owner_id == user_id {
        return Ok(());
    }

    let membership = OrgMember::find(pool, org_id, user_id)
        .await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    match membership {
        Some(_) => Ok(()),
        None => Err(StatusCode::FORBIDDEN),
    }
}
/// The organization a request has been resolved to act on.
///
/// Produced by `resolve_org_context` after the caller's access has been checked.
#[derive(Clone, Debug)]
pub struct OrgContext {
    /// Id of the resolved organization (set to `org.id` by `resolve_org_context`).
    pub org_id: Uuid,
    /// The resolved organization record.
    pub org: Organization,
}
pub async fn resolve_org_context(
state: &AppState,
user_id: Uuid,
headers: &HeaderMap,
request_extensions: Option<&axum::http::Extensions>, ) -> Result<OrgContext, StatusCode> {
let pool = state.db.pool();
let api_token_org_id = request_extensions.and_then(|ext| {
ext.get::<String>()
.and_then(|s| s.strip_prefix("org_id:").and_then(|rest| Uuid::parse_str(rest).ok()))
});
let org = if let Some(org_id) = api_token_org_id {
let org = Organization::find_by_id(pool, org_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
verify_org_access(pool, org_id, user_id)
.await
.map_err(|_| StatusCode::FORBIDDEN)?;
org
} else if let Some(org_id_header) = headers.get("X-Organization-Id") {
let org_id_str = org_id_header.to_str().map_err(|_| StatusCode::BAD_REQUEST)?;
let org_id = Uuid::parse_str(org_id_str).map_err(|_| StatusCode::BAD_REQUEST)?;
let org = Organization::find_by_id(pool, org_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
verify_org_access(pool, org_id, user_id)
.await
.map_err(|_| StatusCode::FORBIDDEN)?;
org
} else if let Some(org_slug_header) = headers.get("X-Organization-Slug") {
let slug = org_slug_header.to_str().map_err(|_| StatusCode::BAD_REQUEST)?;
let org = Organization::find_by_slug(pool, slug)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
verify_org_access(pool, org.id, user_id)
.await
.map_err(|_| StatusCode::FORBIDDEN)?;
org
} else {
let orgs = Organization::find_by_user(pool, user_id)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
orgs.into_iter().find(|o| o.owner_id == user_id).ok_or(StatusCode::NOT_FOUND)?
};
Ok(OrgContext {
org_id: org.id,
org,
})
}