Skip to main content

mockforge_registry_server/middleware/
org_context.rs

//! Organization context middleware for multi-tenancy.
//!
//! Resolves which organization a request operates on and provides it to
//! handlers. Sources, in order of precedence:
//! - the org id attached to the request by API token authentication
//! - the `X-Organization-Id` header
//! - the `X-Organization-Slug` header
//! - otherwise, an organization owned by the authenticated user
8
9use axum::http::{HeaderMap, StatusCode};
10use uuid::Uuid;
11
12use crate::{models::Organization, AppState};
13
14/// Verify user has access to organization
15async fn verify_org_access(
16    pool: &sqlx::PgPool,
17    org_id: Uuid,
18    user_id: Uuid,
19) -> Result<(), StatusCode> {
20    use crate::models::OrgMember;
21
22    // Check if user is owner
23    let org = Organization::find_by_id(pool, org_id)
24        .await
25        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
26        .ok_or(StatusCode::NOT_FOUND)?;
27
28    if org.owner_id == user_id {
29        return Ok(());
30    }
31
32    // Check if user is a member
33    let member = OrgMember::find(pool, org_id, user_id)
34        .await
35        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
36
37    if member.is_some() {
38        Ok(())
39    } else {
40        Err(StatusCode::FORBIDDEN)
41    }
42}
43
44/// Organization context
45#[derive(Debug, Clone)]
46pub struct OrgContext {
47    pub org_id: Uuid,
48    pub org: Organization,
49}
50
51/// Helper function to resolve org context from State and AuthUser
52/// Use this in handlers instead of the extractor if you need more control
53///
54/// Also checks request extensions for org_id set by API token auth
55pub async fn resolve_org_context(
56    state: &AppState,
57    user_id: Uuid,
58    headers: &HeaderMap,
59    request_extensions: Option<&axum::http::Extensions>, // Optional extensions from request
60) -> Result<OrgContext, StatusCode> {
61    let pool = state.db.pool();
62
63    // Check if org_id was set by API token auth (for faster lookup)
64    let api_token_org_id = request_extensions.and_then(|ext| {
65        // Try to get org_id from extensions
66        ext.get::<String>()
67            .and_then(|s| s.strip_prefix("org_id:").and_then(|rest| Uuid::parse_str(rest).ok()))
68    });
69
70    // Try to get org from API token first, then header, then default
71    let org = if let Some(org_id) = api_token_org_id {
72        // Use org_id from API token (fastest path)
73        // Try cache first
74        let org = Organization::find_by_id(pool, org_id)
75            .await
76            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
77            .ok_or(StatusCode::NOT_FOUND)?;
78
79        // Verify user has access (already verified by API token, but double-check)
80        verify_org_access(pool, org_id, user_id)
81            .await
82            .map_err(|_| StatusCode::FORBIDDEN)?;
83
84        org
85    } else if let Some(org_id_header) = headers.get("X-Organization-Id") {
86        // Resolve by ID
87        let org_id_str = org_id_header.to_str().map_err(|_| StatusCode::BAD_REQUEST)?;
88        let org_id = Uuid::parse_str(org_id_str).map_err(|_| StatusCode::BAD_REQUEST)?;
89
90        // Try cache first
91        let org = Organization::find_by_id(pool, org_id)
92            .await
93            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
94            .ok_or(StatusCode::NOT_FOUND)?;
95
96        // Verify user has access to this org
97        verify_org_access(pool, org_id, user_id)
98            .await
99            .map_err(|_| StatusCode::FORBIDDEN)?;
100
101        org
102    } else if let Some(org_slug_header) = headers.get("X-Organization-Slug") {
103        // Resolve by slug
104        let slug = org_slug_header.to_str().map_err(|_| StatusCode::BAD_REQUEST)?;
105
106        let org = Organization::find_by_slug(pool, slug)
107            .await
108            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
109            .ok_or(StatusCode::NOT_FOUND)?;
110
111        // Verify user has access to this org
112        verify_org_access(pool, org.id, user_id)
113            .await
114            .map_err(|_| StatusCode::FORBIDDEN)?;
115
116        org
117    } else {
118        // Get user's default/personal org
119        // For now, get the first org where user is owner
120        // In the future, we might store a "default_org_id" on users
121        let orgs = Organization::find_by_user(pool, user_id)
122            .await
123            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
124
125        orgs.into_iter().find(|o| o.owner_id == user_id).ok_or(StatusCode::NOT_FOUND)?
126    };
127
128    Ok(OrgContext {
129        org_id: org.id,
130        org,
131    })
132}