Skip to main content

docbox_http/middleware/
tenant.rs

1//! Extractor for extracting the current tenant from the API headers
2
3use std::sync::Arc;
4
5use crate::error::{DynHttpError, HttpCommonError, HttpError};
6use axum::{
7    Extension,
8    extract::{FromRequestParts, Request},
9    http::{HeaderMap, StatusCode, request::Parts},
10    middleware::Next,
11    response::Response,
12};
13use docbox_core::{
14    database::{DatabasePoolCache, DbPool, models::tenant::Tenant},
15    events::{EventPublisherFactory, TenantEventPublisher},
16    search::{SearchIndexFactory, TenantSearchIndex},
17    storage::{StorageLayer, StorageLayerFactory},
18    tenant::{tenant_cache::TenantCache, tenant_options_ext::TenantOptionsExt},
19};
20use thiserror::Error;
21use tracing::Instrument;
22use utoipa::IntoParams;
23use uuid::Uuid;
24
/// Request header carrying the ID (a UUID) of the tenant the request targets
pub const TENANT_ID_HEADER: &str = "x-tenant-id";
/// Request header carrying the environment of the tenant the request targets
pub const TENANT_ENV_HEADER: &str = "x-tenant-env";
29
/// OpenAPI param for requiring the tenant identifier headers
// Describes the `x-tenant-id` / `x-tenant-env` headers for the generated OpenAPI
// spec; at runtime the headers are parsed by `extract_tenant`, not via this struct.
#[derive(IntoParams)]
#[into_params(parameter_in = Header)]
// NOTE(review): the struct appears to exist only for the derived OpenAPI metadata
// (its fields are never read here), hence the lint allowance — confirm nothing
// constructs it before narrowing this to a more specific lint.
#[allow(unused)]
pub struct TenantParams {
    /// ID of the tenant you are targeting
    #[param(rename = "x-tenant-id")]
    pub tenant_id: String,
    /// Environment of the tenant you are targeting
    #[param(rename = "x-tenant-env")]
    pub tenant_env: String,
}
42
43/// Authenticates the requested tenant, loads the tenant from the database and stores it
44/// on the request extensions so it can be extracted by handlers
45pub async fn tenant_auth_middleware(
46    headers: HeaderMap,
47    db_cache: Extension<Arc<DatabasePoolCache>>,
48    tenant_cache: Extension<Arc<TenantCache>>,
49    mut request: Request,
50    next: Next,
51) -> Result<Response, DynHttpError> {
52    // Extract the request tenant
53    let tenant = extract_tenant(&headers, &db_cache, &tenant_cache).await?;
54
55    // Provide a request span that contains the tenant metadata
56    let span = tracing::info_span!("tenant", tenant_id = %tenant.id, tenant_env = %tenant.env);
57
58    // Add the tenant as an extension
59    request.extensions_mut().insert(tenant);
60
61    // Continue the request normally
62    Ok(next.run(request).instrument(span).await)
63}
64
65pub fn get_tenant_env(headers: &HeaderMap) -> Result<String, ExtractTenantError> {
66    match headers.get(TENANT_ENV_HEADER) {
67        Some(value) => value
68            .to_str()
69            .map_err(|_| ExtractTenantError::InvalidTenantEnv)
70            .map(|value| value.to_string()),
71
72        // Tenant not provided
73        None => Err(ExtractTenantError::MissingTenantEnv),
74    }
75}
76
77#[derive(Debug, Error)]
78pub enum ExtractTenantError {
79    #[error("tenant id is required")]
80    MissingTenantId,
81    #[error("tenant id must be a valid uuid")]
82    InvalidTenantId,
83    #[error("tenant env is required")]
84    MissingTenantEnv,
85    #[error("tenant env must be a valid uuid")]
86    InvalidTenantEnv,
87    #[error("tenant not found")]
88    TenantNotFound,
89}
90
91impl HttpError for ExtractTenantError {
92    fn status(&self) -> axum::http::StatusCode {
93        StatusCode::BAD_REQUEST
94    }
95}
96
97/// Extracts the target tenant for the provided request
98pub async fn extract_tenant(
99    headers: &HeaderMap,
100    db_cache: &DatabasePoolCache,
101    tenant_cache: &TenantCache,
102) -> Result<Tenant, DynHttpError> {
103    let tenant_id: Uuid = match headers.get(TENANT_ID_HEADER) {
104        Some(value) => {
105            let value_str = value
106                .to_str()
107                .map_err(|_| ExtractTenantError::InvalidTenantId)?;
108
109            value_str
110                .parse()
111                .map_err(|_| ExtractTenantError::InvalidTenantId)?
112        }
113
114        // Tenant not provided
115        None => return Err(ExtractTenantError::MissingTenantId.into()),
116    };
117
118    let env = get_tenant_env(headers)?;
119
120    let db = db_cache.get_root_pool().await.map_err(|error| {
121        tracing::error!(?error, "failed to connect to root database");
122        HttpCommonError::ServerError
123    })?;
124
125    let tenant = tenant_cache
126        .get_tenant(&db, env, tenant_id)
127        .await
128        .map_err(|error| {
129            tracing::error!(?error, "failed to query root tenant");
130            HttpCommonError::ServerError
131        })?
132        .ok_or(ExtractTenantError::TenantNotFound)?;
133
134    Ok(tenant)
135}
136
137/// Extractor to get database access for the current tenant
138pub struct TenantDb(pub DbPool);
139
140impl<S> FromRequestParts<S> for TenantDb
141where
142    S: Send + Sync,
143{
144    type Rejection = DynHttpError;
145
146    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
147        // Extract current tenant
148        let tenant: &Tenant = parts.extensions.get().ok_or_else(|| {
149            tracing::error!("tenant not available within this scope");
150            HttpCommonError::ServerError
151        })?;
152
153        // Extract database cache
154        let db_cache: &Arc<DatabasePoolCache> = parts.extensions.get().ok_or_else(|| {
155            tracing::error!("database pool caching is missing");
156            HttpCommonError::ServerError
157        })?;
158
159        // Create the database connection pool
160        let db = db_cache.get_tenant_pool(tenant).await.map_err(|error| {
161            tracing::error!(?error, "failed to connect to root database");
162            HttpCommonError::ServerError
163        })?;
164
165        Ok(TenantDb(db))
166    }
167}
168
169/// Tenant open search instance
170pub struct TenantSearch(pub TenantSearchIndex);
171
172impl<S> FromRequestParts<S> for TenantSearch
173where
174    S: Send + Sync,
175{
176    type Rejection = DynHttpError;
177
178    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
179        // Extract current tenant
180        let tenant: &Tenant = parts.extensions.get().ok_or_else(|| {
181            tracing::error!("tenant not available within this scope");
182            HttpCommonError::ServerError
183        })?;
184
185        // Extract search index factory
186        let factory: &SearchIndexFactory = parts.extensions.get().ok_or_else(|| {
187            tracing::error!("search index factory is missing");
188            HttpCommonError::ServerError
189        })?;
190
191        // Create search index
192        let search = factory.create_search_index(tenant);
193
194        Ok(TenantSearch(search))
195    }
196}
197
198/// Tenant storage access
199pub struct TenantStorage(pub StorageLayer);
200
201impl<S> FromRequestParts<S> for TenantStorage
202where
203    S: Send + Sync,
204{
205    type Rejection = DynHttpError;
206
207    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
208        // Extract current tenant
209        let tenant: &Tenant = parts.extensions.get().ok_or_else(|| {
210            tracing::error!("tenant not available within this scope");
211            HttpCommonError::ServerError
212        })?;
213
214        // Extract open search access
215        let factory: &StorageLayerFactory = parts.extensions.get().ok_or_else(|| {
216            tracing::error!("storage layer is missing");
217            HttpCommonError::ServerError
218        })?;
219
220        // Create tenant storage layer
221        let storage = factory.create_layer(tenant.storage_layer_options());
222
223        Ok(TenantStorage(storage))
224    }
225}
226
227/// Tenant events access
228pub struct TenantEvents(pub TenantEventPublisher);
229
230impl<S> FromRequestParts<S> for TenantEvents
231where
232    S: Send + Sync,
233{
234    type Rejection = DynHttpError;
235
236    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
237        // Extract current tenant
238        let tenant: &Tenant = parts.extensions.get().ok_or_else(|| {
239            tracing::error!("tenant not available within this scope");
240            HttpCommonError::ServerError
241        })?;
242
243        // Get the event publisher factor
244        let events: &EventPublisherFactory = parts.extensions.get().ok_or_else(|| {
245            tracing::error!("event publisher layer is missing");
246            HttpCommonError::ServerError
247        })?;
248
249        Ok(TenantEvents(events.create_event_publisher(tenant)))
250    }
251}