docbox_http/middleware/
tenant.rs1use std::sync::Arc;
4
5use crate::error::{DynHttpError, HttpCommonError, HttpError};
6use axum::{
7 Extension,
8 extract::{FromRequestParts, Request},
9 http::{HeaderMap, StatusCode, request::Parts},
10 middleware::Next,
11 response::Response,
12};
13use docbox_core::{
14 database::{DatabasePoolCache, DbPool, models::tenant::Tenant},
15 events::{EventPublisherFactory, TenantEventPublisher},
16 search::{SearchIndexFactory, TenantSearchIndex},
17 storage::{StorageLayer, StorageLayerFactory},
18 tenant::{tenant_cache::TenantCache, tenant_options_ext::TenantOptionsExt},
19};
20use thiserror::Error;
21use tracing::Instrument;
22use utoipa::IntoParams;
23use uuid::Uuid;
24
/// Name of the request header that carries the tenant's UUID.
pub const TENANT_ID_HEADER: &str = "x-tenant-id";
/// Name of the request header that carries the tenant's environment.
pub const TENANT_ENV_HEADER: &str = "x-tenant-env";
29
30#[derive(IntoParams)]
32#[into_params(parameter_in = Header)]
33#[allow(unused)]
34pub struct TenantParams {
35 #[param(rename = "x-tenant-id")]
37 pub tenant_id: String,
38 #[param(rename = "x-tenant-env")]
40 pub tenant_env: String,
41}
42
43pub async fn tenant_auth_middleware(
46 headers: HeaderMap,
47 db_cache: Extension<Arc<DatabasePoolCache>>,
48 tenant_cache: Extension<Arc<TenantCache>>,
49 mut request: Request,
50 next: Next,
51) -> Result<Response, DynHttpError> {
52 let tenant = extract_tenant(&headers, &db_cache, &tenant_cache).await?;
54
55 let span = tracing::info_span!("tenant", tenant_id = %tenant.id, tenant_env = %tenant.env);
57
58 request.extensions_mut().insert(tenant);
60
61 Ok(next.run(request).instrument(span).await)
63}
64
65pub fn get_tenant_env(headers: &HeaderMap) -> Result<String, ExtractTenantError> {
66 match headers.get(TENANT_ENV_HEADER) {
67 Some(value) => value
68 .to_str()
69 .map_err(|_| ExtractTenantError::InvalidTenantEnv)
70 .map(|value| value.to_string()),
71
72 None => Err(ExtractTenantError::MissingTenantEnv),
74 }
75}
76
77#[derive(Debug, Error)]
78pub enum ExtractTenantError {
79 #[error("tenant id is required")]
80 MissingTenantId,
81 #[error("tenant id must be a valid uuid")]
82 InvalidTenantId,
83 #[error("tenant env is required")]
84 MissingTenantEnv,
85 #[error("tenant env must be a valid uuid")]
86 InvalidTenantEnv,
87 #[error("tenant not found")]
88 TenantNotFound,
89}
90
91impl HttpError for ExtractTenantError {
92 fn status(&self) -> axum::http::StatusCode {
93 StatusCode::BAD_REQUEST
94 }
95}
96
97pub async fn extract_tenant(
99 headers: &HeaderMap,
100 db_cache: &DatabasePoolCache,
101 tenant_cache: &TenantCache,
102) -> Result<Tenant, DynHttpError> {
103 let tenant_id: Uuid = match headers.get(TENANT_ID_HEADER) {
104 Some(value) => {
105 let value_str = value
106 .to_str()
107 .map_err(|_| ExtractTenantError::InvalidTenantId)?;
108
109 value_str
110 .parse()
111 .map_err(|_| ExtractTenantError::InvalidTenantId)?
112 }
113
114 None => return Err(ExtractTenantError::MissingTenantId.into()),
116 };
117
118 let env = get_tenant_env(headers)?;
119
120 let db = db_cache.get_root_pool().await.map_err(|error| {
121 tracing::error!(?error, "failed to connect to root database");
122 HttpCommonError::ServerError
123 })?;
124
125 let tenant = tenant_cache
126 .get_tenant(&db, env, tenant_id)
127 .await
128 .map_err(|error| {
129 tracing::error!(?error, "failed to query root tenant");
130 HttpCommonError::ServerError
131 })?
132 .ok_or(ExtractTenantError::TenantNotFound)?;
133
134 Ok(tenant)
135}
136
137pub struct TenantDb(pub DbPool);
139
140impl<S> FromRequestParts<S> for TenantDb
141where
142 S: Send + Sync,
143{
144 type Rejection = DynHttpError;
145
146 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
147 let tenant: &Tenant = parts.extensions.get().ok_or_else(|| {
149 tracing::error!("tenant not available within this scope");
150 HttpCommonError::ServerError
151 })?;
152
153 let db_cache: &Arc<DatabasePoolCache> = parts.extensions.get().ok_or_else(|| {
155 tracing::error!("database pool caching is missing");
156 HttpCommonError::ServerError
157 })?;
158
159 let db = db_cache.get_tenant_pool(tenant).await.map_err(|error| {
161 tracing::error!(?error, "failed to connect to root database");
162 HttpCommonError::ServerError
163 })?;
164
165 Ok(TenantDb(db))
166 }
167}
168
169pub struct TenantSearch(pub TenantSearchIndex);
171
172impl<S> FromRequestParts<S> for TenantSearch
173where
174 S: Send + Sync,
175{
176 type Rejection = DynHttpError;
177
178 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
179 let tenant: &Tenant = parts.extensions.get().ok_or_else(|| {
181 tracing::error!("tenant not available within this scope");
182 HttpCommonError::ServerError
183 })?;
184
185 let factory: &SearchIndexFactory = parts.extensions.get().ok_or_else(|| {
187 tracing::error!("search index factory is missing");
188 HttpCommonError::ServerError
189 })?;
190
191 let search = factory.create_search_index(tenant);
193
194 Ok(TenantSearch(search))
195 }
196}
197
198pub struct TenantStorage(pub StorageLayer);
200
201impl<S> FromRequestParts<S> for TenantStorage
202where
203 S: Send + Sync,
204{
205 type Rejection = DynHttpError;
206
207 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
208 let tenant: &Tenant = parts.extensions.get().ok_or_else(|| {
210 tracing::error!("tenant not available within this scope");
211 HttpCommonError::ServerError
212 })?;
213
214 let factory: &StorageLayerFactory = parts.extensions.get().ok_or_else(|| {
216 tracing::error!("storage layer is missing");
217 HttpCommonError::ServerError
218 })?;
219
220 let storage = factory.create_layer(tenant.storage_layer_options());
222
223 Ok(TenantStorage(storage))
224 }
225}
226
227pub struct TenantEvents(pub TenantEventPublisher);
229
230impl<S> FromRequestParts<S> for TenantEvents
231where
232 S: Send + Sync,
233{
234 type Rejection = DynHttpError;
235
236 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
237 let tenant: &Tenant = parts.extensions.get().ok_or_else(|| {
239 tracing::error!("tenant not available within this scope");
240 HttpCommonError::ServerError
241 })?;
242
243 let events: &EventPublisherFactory = parts.extensions.get().ok_or_else(|| {
245 tracing::error!("event publisher layer is missing");
246 HttpCommonError::ServerError
247 })?;
248
249 Ok(TenantEvents(events.create_event_publisher(tenant)))
250 }
251}