use super::{MultiTenantWorkspaceRegistry, TenantWorkspace};
use crate::{Error, Result};
use std::sync::Arc;
/// The result of resolving a request path to a tenant workspace.
///
/// Produced by `WorkspaceRouter::extract_workspace_context`.
#[derive(Debug, Clone)]
pub struct WorkspaceContext {
    /// ID of the resolved workspace. This is the configured default workspace ID
    /// when multi-tenancy is disabled or the path carries no workspace ID.
    pub workspace_id: String,
    /// The path exactly as it was passed in.
    pub original_path: String,
    /// The path with the workspace prefix removed. It is identical to
    /// `original_path` when the default workspace was used.
    pub stripped_path: String,
    /// The resolved workspace, as returned by the registry.
    pub workspace: TenantWorkspace,
}
/// Resolves request paths to tenant workspaces using a shared
/// [`MultiTenantWorkspaceRegistry`].
#[derive(Debug, Clone)]
pub struct WorkspaceRouter {
    // Held behind `Arc`, so clones of the router share the same registry.
    registry: Arc<MultiTenantWorkspaceRegistry>,
}
impl WorkspaceRouter {
    /// Creates a router backed by the given workspace registry.
    pub fn new(registry: Arc<MultiTenantWorkspaceRegistry>) -> Self {
        Self { registry }
    }

    /// Resolves which workspace `path` belongs to.
    ///
    /// - If multi-tenancy is disabled in the registry config, the configured
    ///   default workspace is used and the path is left unchanged.
    /// - Otherwise, if the registry extracts a workspace ID from `path`, that
    ///   workspace is looked up and its prefix is stripped from the path.
    /// - Otherwise, it falls back to the default workspace with the path unchanged.
    ///
    /// # Errors
    ///
    /// - Propagates any error from the registry lookups, such as an unknown
    ///   workspace ID or a missing default workspace.
    /// - Returns an internal error if the workspace addressed by the path is disabled.
    pub fn extract_workspace_context(&self, path: &str) -> Result<WorkspaceContext> {
        if !self.registry.config().enabled {
            return self.default_context(path);
        }

        let workspace_id = match self.registry.extract_workspace_id_from_path(path) {
            Some(id) => id,
            None => return self.default_context(path),
        };

        let workspace = self.registry.get_workspace(&workspace_id)?;
        if !workspace.enabled {
            return Err(Error::internal(format!("Workspace '{}' is disabled", workspace_id)));
        }

        let stripped_path = self.registry.strip_workspace_prefix(path, &workspace_id);
        Ok(WorkspaceContext {
            // Last use of `workspace_id`, so it is moved rather than cloned.
            workspace_id,
            original_path: path.to_string(),
            stripped_path,
            workspace,
        })
    }

    /// Builds the context for a path served by the default workspace.
    /// The path is passed through unchanged.
    fn default_context(&self, path: &str) -> Result<WorkspaceContext> {
        let workspace = self.registry.get_default_workspace()?;
        Ok(WorkspaceContext {
            workspace_id: self.registry.config().default_workspace.clone(),
            original_path: path.to_string(),
            stripped_path: path.to_string(),
            workspace,
        })
    }

    /// Returns the shared registry this router resolves against.
    pub fn registry(&self) -> &Arc<MultiTenantWorkspaceRegistry> {
        &self.registry
    }

    /// Looks up a workspace by ID, delegating to the registry.
    pub fn get_workspace(&self, workspace_id: &str) -> Result<TenantWorkspace> {
        self.registry.get_workspace(workspace_id)
    }

    /// Returns the registry config's `enabled` flag.
    pub fn is_multi_tenant_enabled(&self) -> bool {
        self.registry.config().enabled
    }

    /// Returns the `workspace_prefix` value from the registry config.
    pub fn workspace_prefix(&self) -> &str {
        &self.registry.config().workspace_prefix
    }
}
/// Axum integration: a middleware that resolves each request's tenant workspace,
/// and an extension trait for reading the result back from a request.
pub mod axum_middleware {
    use super::*;
    use ::axum::http::StatusCode;
    use ::axum::{
        extract::Request,
        middleware::Next,
        response::{IntoResponse, Response},
    };

    /// Resolves the [`WorkspaceContext`] for `request` and forwards the request.
    ///
    /// - If a workspace prefix was stripped, the request URI is rewritten to the
    ///   stripped path. The original query string is kept.
    /// - The context is stored in the request extensions before calling `next`.
    /// - If resolution fails, the request short-circuits with `404 Not Found` and a
    ///   `"Workspace error: ..."` body.
    pub async fn workspace_middleware(
        router: Arc<WorkspaceRouter>,
        mut request: Request,
        next: Next,
    ) -> Response {
        let context = match router.extract_workspace_context(request.uri().path()) {
            Ok(ctx) => ctx,
            Err(e) => {
                return (StatusCode::NOT_FOUND, format!("Workspace error: {}", e)).into_response();
            }
        };
        if context.original_path != context.stripped_path {
            rewrite_path(&mut request, &context.stripped_path);
        }
        // Inserted after the rewrite, so the owned context can be moved in
        // instead of cloned.
        request.extensions_mut().insert(context);
        next.run(request).await
    }

    /// Replaces the path component of `request`'s URI with `new_path`.
    ///
    /// The original query string, if any, is kept. If the rebuilt URI fails to
    /// parse, the request is left untouched.
    fn rewrite_path(request: &mut Request, new_path: &str) {
        // `stripped_path` comes from `Uri::path()` and so carries no query.
        // Re-attach the query here so it is not silently dropped.
        let path_and_query = match request.uri().query() {
            Some(query) => format!("{}?{}", new_path, query),
            None => new_path.to_string(),
        };
        let mut parts = request.uri().clone().into_parts();
        parts.path_and_query = path_and_query.parse().ok().or(parts.path_and_query);
        if let Ok(uri) = ::axum::http::Uri::from_parts(parts) {
            *request.uri_mut() = uri;
        }
    }

    /// Accessors for the [`WorkspaceContext`] that [`workspace_middleware`]
    /// stores in the request extensions.
    pub trait WorkspaceContextExt {
        /// Returns the stored context, or `None` if the middleware did not run.
        fn workspace_context(&self) -> Option<&WorkspaceContext>;
        /// Returns the resolved workspace ID, if a context is stored.
        fn workspace_id(&self) -> Option<&str>;
        /// Returns the path with the workspace prefix removed, if a context is stored.
        fn stripped_path(&self) -> Option<&str>;
    }

    impl WorkspaceContextExt for Request {
        fn workspace_context(&self) -> Option<&WorkspaceContext> {
            self.extensions().get::<WorkspaceContext>()
        }

        fn workspace_id(&self) -> Option<&str> {
            self.workspace_context().map(|ctx| ctx.workspace_id.as_str())
        }

        fn stripped_path(&self) -> Option<&str> {
            self.workspace_context().map(|ctx| ctx.stripped_path.as_str())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multi_tenant::{MultiTenantConfig, MultiTenantWorkspaceRegistry};
    use crate::workspace::Workspace;

    /// Builds a router over a fresh registry.
    ///
    /// The registry gets the given `enabled` flag, and each
    /// `(workspace id, display name)` pair is registered in list order.
    fn router_with(enabled: bool, workspaces: &[(&str, &str)]) -> WorkspaceRouter {
        let mut registry = MultiTenantWorkspaceRegistry::new(MultiTenantConfig {
            enabled,
            ..Default::default()
        });
        for &(id, display_name) in workspaces {
            let ws = Workspace::new(display_name.to_string());
            registry.register_workspace(id.to_string(), ws).unwrap();
        }
        WorkspaceRouter::new(Arc::new(registry))
    }

    /// Multi-tenant router holding a `default` and a `test` workspace.
    fn create_test_router() -> WorkspaceRouter {
        router_with(true, &[("default", "Default"), ("test", "Test Workspace")])
    }

    #[test]
    fn test_extract_workspace_context_with_prefix() {
        let ctx = create_test_router()
            .extract_workspace_context("/workspace/test/api/users")
            .unwrap();
        assert_eq!(ctx.workspace_id, "test");
        assert_eq!(ctx.original_path, "/workspace/test/api/users");
        assert_eq!(ctx.stripped_path, "/api/users");
        assert_eq!(ctx.workspace.name(), "Test Workspace");
    }

    #[test]
    fn test_extract_workspace_context_default() {
        let ctx = create_test_router()
            .extract_workspace_context("/api/users")
            .unwrap();
        assert_eq!(ctx.workspace_id, "default");
        assert_eq!(ctx.original_path, "/api/users");
        assert_eq!(ctx.stripped_path, "/api/users");
        assert_eq!(ctx.workspace.name(), "Default");
    }

    #[test]
    fn test_extract_workspace_context_nonexistent() {
        let outcome =
            create_test_router().extract_workspace_context("/workspace/nonexistent/api/users");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_multi_tenant_disabled() {
        let ctx = router_with(false, &[("default", "Default")])
            .extract_workspace_context("/workspace/test/api/users")
            .unwrap();
        assert_eq!(ctx.workspace_id, "default");
        assert_eq!(ctx.stripped_path, "/workspace/test/api/users");
    }
}