use std::collections::HashMap;
use std::sync::Arc;
use serde_json;
use crate::errors::HttpError;
use crate::middleware::v2::{Middleware, Next, NextFuture};
use crate::request::ElifRequest;
use crate::response::ElifResponse;
use elif_core::container::{IocContainer, ScopeId};
/// Middleware that can construct itself from an IoC container.
///
/// Implementors resolve whatever they need from `container` (optionally within
/// `scope`) and report failure as a human-readable `String`, which
/// [`IocMiddlewareFactory::create`] wraps into an `HttpError::InternalError`.
pub trait IocMiddleware: Middleware {
    /// Constructs the middleware from `container`, optionally within `scope`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable description of why construction failed.
    fn from_ioc_container(
        container: &IocContainer,
        scope: Option<&ScopeId>,
    ) -> Result<Self, String>
    where
        Self: Sized;
}

/// Zero-sized factory that builds `M` through [`IocMiddleware::from_ioc_container`].
pub struct IocMiddlewareFactory<M> {
    // `fn() -> M` instead of `M`: the factory only *produces* values of type `M`
    // and never stores one, so its auto traits (`Send`/`Sync`) and drop-check
    // behaviour must not depend on `M`.
    _phantom: std::marker::PhantomData<fn() -> M>,
}

impl<M> IocMiddlewareFactory<M> {
    /// Creates a factory for `M`.
    pub const fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}

impl<M> Default for IocMiddlewareFactory<M> {
    // Hand-written rather than derived so no `M: Default` bound is introduced.
    fn default() -> Self {
        Self::new()
    }
}

impl<M> IocMiddlewareFactory<M>
where
    M: IocMiddleware + 'static,
{
    /// Builds a new `M` from `container`, optionally within `scope`.
    ///
    /// # Errors
    ///
    /// Returns `HttpError::InternalError` carrying the message produced by
    /// [`IocMiddleware::from_ioc_container`] when construction fails.
    pub fn create(
        &self,
        container: &IocContainer,
        scope: Option<&ScopeId>,
    ) -> Result<M, HttpError> {
        M::from_ioc_container(container, scope).map_err(|e| HttpError::InternalError {
            message: format!("Failed to create middleware: {}", e),
        })
    }
}
/// Name-keyed collection of middleware factories bound to one shared IoC container.
pub struct MiddlewareRegistry {
    // Factories keyed by the name they were registered under.
    factories: HashMap<String, Box<dyn MiddlewareFactory>>,
    // Container passed to every factory when it builds a middleware instance.
    container: Arc<IocContainer>,
}

/// Object-safe factory producing type-erased middleware from an IoC container.
///
/// Implementations must be `Send + Sync`.
pub trait MiddlewareFactory: Send + Sync {
    /// Builds a fresh middleware instance from `container`, optionally within `scope`.
    ///
    /// # Errors
    ///
    /// Returns an [`HttpError`] when the middleware cannot be constructed.
    fn create_middleware(
        &self,
        container: &IocContainer,
        scope: Option<&ScopeId>,
    ) -> Result<Arc<dyn Middleware>, HttpError>;
}

impl<M> MiddlewareFactory for IocMiddlewareFactory<M>
where
    M: IocMiddleware + 'static,
{
    fn create_middleware(
        &self,
        container: &IocContainer,
        scope: Option<&ScopeId>,
    ) -> Result<Arc<dyn Middleware>, HttpError> {
        // Build the concrete `M`, then erase it behind the trait object.
        let erased: Arc<dyn Middleware> = Arc::new(self.create(container, scope)?);
        Ok(erased)
    }
}
impl MiddlewareRegistry {
    /// Creates an empty registry whose factories resolve dependencies from `container`.
    pub fn new(container: Arc<IocContainer>) -> Self {
        Self {
            factories: HashMap::new(),
            container,
        }
    }

    /// Registers `M` under `name`, constructed through its [`IocMiddleware`] impl.
    ///
    /// Any middleware previously registered under `name` is replaced.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub fn register<M>(&mut self, name: &str) -> Result<(), HttpError>
    where
        M: IocMiddleware + 'static,
    {
        self.register_factory(name, Box::new(IocMiddlewareFactory::<M>::new()));
        Ok(())
    }

    /// Registers a custom `factory` under `name`, replacing any existing entry.
    pub fn register_factory(&mut self, name: &str, factory: Box<dyn MiddlewareFactory>) {
        self.factories.insert(name.to_string(), factory);
    }

    /// Builds a new instance of the middleware registered under `name`.
    ///
    /// # Errors
    ///
    /// Returns `HttpError::InternalError` if `name` is not registered, or the
    /// error reported by the registered factory.
    pub fn create_middleware(
        &self,
        name: &str,
        scope: Option<&ScopeId>,
    ) -> Result<Arc<dyn Middleware>, HttpError> {
        let factory = self
            .factories
            .get(name)
            .ok_or_else(|| HttpError::InternalError {
                message: format!("Middleware '{}' not registered", name),
            })?;
        factory.create_middleware(&self.container, scope)
    }

    /// Builds every middleware in `names`, in the given order.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error from [`Self::create_middleware`].
    pub fn create_middleware_pipeline(
        &self,
        names: &[&str],
        scope: Option<&ScopeId>,
    ) -> Result<Vec<Arc<dyn Middleware>>, HttpError> {
        names
            .iter()
            .map(|name| self.create_middleware(name, scope))
            .collect()
    }

    /// Returns the names of all registered middleware, sorted lexicographically.
    ///
    /// `HashMap` iteration order is arbitrary and differs between runs; sorting
    /// makes the result deterministic for logging, diagnostics and tests.
    pub fn registered_middleware(&self) -> Vec<String> {
        let mut names: Vec<String> = self.factories.keys().cloned().collect();
        names.sort_unstable();
        names
    }
}
/// Fluent builder for [`MiddlewareRegistry`].
///
/// Registrations are collected in order and applied by [`build`](Self::build).
/// When the same name is registered more than once, the last registration wins.
#[must_use = "a MiddlewareRegistryBuilder does nothing until `build` is called"]
pub struct MiddlewareRegistryBuilder {
    container: Option<Arc<IocContainer>>,
    middleware: Vec<(String, Box<dyn MiddlewareFactory>)>,
}

impl MiddlewareRegistryBuilder {
    /// Creates a builder with no container and no registrations.
    pub fn new() -> Self {
        Self {
            container: None,
            middleware: Vec::new(),
        }
    }

    /// Sets the IoC container the built registry will use (required).
    pub fn container(mut self, container: Arc<IocContainer>) -> Self {
        self.container = Some(container);
        self
    }

    /// Queues `M` for registration under `name`.
    pub fn register<M>(mut self, name: &str) -> Self
    where
        M: IocMiddleware + 'static,
    {
        let factory = Box::new(IocMiddlewareFactory::<M>::new());
        self.middleware.push((name.to_string(), factory));
        self
    }

    /// Queues a custom `factory` for registration under `name`.
    pub fn register_factory(mut self, name: &str, factory: Box<dyn MiddlewareFactory>) -> Self {
        self.middleware.push((name.to_string(), factory));
        self
    }

    /// Consumes the builder and produces the registry.
    ///
    /// # Errors
    ///
    /// Returns `HttpError::InternalError` if no container was supplied via
    /// [`container`](Self::container).
    pub fn build(self) -> Result<MiddlewareRegistry, HttpError> {
        let container = self.container.ok_or_else(|| HttpError::InternalError {
            message: "IoC container is required for middleware registry".to_string(),
        })?;
        let mut registry = MiddlewareRegistry::new(container);
        // Applied in insertion order, so later duplicates replace earlier ones.
        for (name, factory) in self.middleware {
            registry.register_factory(&name, factory);
        }
        Ok(registry)
    }
}

impl Default for MiddlewareRegistryBuilder {
    fn default() -> Self {
        Self::new()
    }
}
pub struct LazyIocMiddleware {
middleware_name: String,
registry: Arc<MiddlewareRegistry>,
}
impl std::fmt::Debug for LazyIocMiddleware {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LazyIocMiddleware")
.field("middleware_name", &self.middleware_name)
.finish()
}
}
impl LazyIocMiddleware {
pub fn new(middleware_name: String, registry: Arc<MiddlewareRegistry>) -> Self {
Self {
middleware_name,
registry,
}
}
}
impl Middleware for LazyIocMiddleware {
fn handle(&self, request: ElifRequest, next: Next) -> NextFuture<'static> {
let middleware_name = self.middleware_name.clone();
let registry = self.registry.clone();
Box::pin(async move {
let scope_result = registry.container.create_scope();
let scope = scope_result.ok();
match registry.create_middleware(&middleware_name, scope.as_ref()) {
Ok(middleware) => {
let result = middleware.handle(request, next).await;
if let Some(scope_id) = scope {
let _ = registry.container.dispose_scope(&scope_id).await;
}
result
}
Err(e) => {
eprintln!(
"CRITICAL: Failed to instantiate middleware '{}': {:?}",
middleware_name, e
);
if let Some(scope_id) = scope {
let _ = registry.container.dispose_scope(&scope_id).await;
}
ElifResponse::internal_server_error()
.json(&serde_json::json!({
"error": {
"code": "MIDDLEWARE_INIT_FAILED",
"message": "Internal server error",
"hint": "A required middleware component failed to initialize"
}
}))
.unwrap_or_else(|_| ElifResponse::internal_server_error())
}
}
})
}
fn name(&self) -> &'static str {
"LazyIocMiddleware"
}
}
#[derive(Clone, Debug)]
pub struct MiddlewareContext {
pub request_id: String,
pub user_id: Option<String>,
pub session_id: Option<String>,
pub correlation_id: Option<String>,
pub custom_data: HashMap<String, String>,
}
impl MiddlewareContext {
pub fn from_request(request: &ElifRequest) -> Self {
Self {
request_id: request
.header("x-request-id")
.and_then(|h| h.to_str().ok())
.unwrap_or("unknown")
.to_string(),
user_id: request
.header("x-user-id")
.and_then(|h| h.to_str().ok())
.map(String::from),
session_id: request
.header("x-session-id")
.and_then(|h| h.to_str().ok())
.map(String::from),
correlation_id: request
.header("x-correlation-id")
.and_then(|h| h.to_str().ok())
.map(String::from),
custom_data: HashMap::new(),
}
}
pub fn with_data(mut self, key: String, value: String) -> Self {
self.custom_data.insert(key, value);
self
}
}
/// Named, ordered set of middleware that are instantiated together from a shared registry.
pub struct MiddlewareGroup {
    name: String,
    middleware_names: Vec<String>,
    registry: Arc<MiddlewareRegistry>,
}

impl MiddlewareGroup {
    /// Creates a group called `name` over `middleware_names`, resolved through `registry`.
    pub fn new(
        name: String,
        middleware_names: Vec<String>,
        registry: Arc<MiddlewareRegistry>,
    ) -> Self {
        Self {
            name,
            middleware_names,
            registry,
        }
    }

    /// Instantiates every member of the group, preserving declaration order.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the registry; the remaining members
    /// are not built.
    pub fn create_middleware(
        &self,
        scope: Option<&ScopeId>,
    ) -> Result<Vec<Arc<dyn Middleware>>, HttpError> {
        let mut built = Vec::with_capacity(self.middleware_names.len());
        for member in &self.middleware_names {
            built.push(self.registry.create_middleware(member, scope)?);
        }
        Ok(built)
    }

    /// The group's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Names of the group's members, in declaration order.
    pub fn middleware_names(&self) -> &[String] {
        &self.middleware_names
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use elif_core::container::{IocContainer, ServiceBinder};

    /// Minimal service resolved by [`TestIocMiddleware`].
    ///
    /// Its only field is a `String`, so the type is automatically `Send + Sync`;
    /// no `unsafe impl` is required.
    #[derive(Default, Clone, Debug)]
    pub struct TestLoggerService {
        pub name: String,
    }

    /// Middleware whose only dependency is a [`TestLoggerService`] from the container.
    #[derive(Debug)]
    pub struct TestIocMiddleware {
        logger: Arc<TestLoggerService>,
    }

    impl IocMiddleware for TestIocMiddleware {
        fn from_ioc_container(
            container: &IocContainer,
            _scope: Option<&ScopeId>,
        ) -> Result<Self, String> {
            let logger = container
                .resolve::<TestLoggerService>()
                .map_err(|e| format!("Failed to resolve TestLoggerService: {}", e))?;
            Ok(Self { logger })
        }
    }

    impl Middleware for TestIocMiddleware {
        fn handle(&self, request: ElifRequest, next: Next) -> NextFuture<'static> {
            let logger_name = self.logger.name.clone();
            Box::pin(async move {
                println!("TestIocMiddleware: Using logger: {}", logger_name);
                next.run(request).await
            })
        }

        fn name(&self) -> &'static str {
            "TestIocMiddleware"
        }
    }

    #[tokio::test]
    async fn test_ioc_middleware_creation() {
        let mut container = IocContainer::new();
        let logger_service = TestLoggerService {
            name: "TestLogger".to_string(),
        };
        container.bind_instance::<TestLoggerService, TestLoggerService>(logger_service);
        container.build().expect("Container build failed");
        let container_arc = Arc::new(container);
        let mut registry = MiddlewareRegistry::new(container_arc);
        registry
            .register::<TestIocMiddleware>("test_middleware")
            .expect("Failed to register middleware");
        let middleware = registry
            .create_middleware("test_middleware", None)
            .expect("Failed to create middleware");
        assert_eq!(middleware.name(), "TestIocMiddleware");
    }

    #[tokio::test]
    async fn test_middleware_registry_builder() {
        let mut container = IocContainer::new();
        container.bind::<TestLoggerService, TestLoggerService>();
        container.build().expect("Container build failed");
        let registry = MiddlewareRegistryBuilder::new()
            .container(Arc::new(container))
            .register::<TestIocMiddleware>("test_ioc")
            .build()
            .expect("Failed to build middleware registry");
        let registered = registry.registered_middleware();
        assert!(registered.contains(&"test_ioc".to_string()));
        let middleware = registry
            .create_middleware("test_ioc", None)
            .expect("Failed to create middleware");
        assert_eq!(middleware.name(), "TestIocMiddleware");
    }

    #[tokio::test]
    async fn test_middleware_pipeline_creation() {
        let mut container = IocContainer::new();
        container.bind::<TestLoggerService, TestLoggerService>();
        container.build().expect("Container build failed");
        let registry = MiddlewareRegistryBuilder::new()
            .container(Arc::new(container))
            .register::<TestIocMiddleware>("ioc1")
            .register::<TestIocMiddleware>("ioc2")
            .build()
            .expect("Failed to build middleware registry");
        let middleware_pipeline = registry
            .create_middleware_pipeline(&["ioc1", "ioc2"], None)
            .expect("Failed to create middleware pipeline");
        assert_eq!(middleware_pipeline.len(), 2);
    }

    #[tokio::test]
    async fn test_lazy_ioc_middleware() {
        use crate::request::method::ElifMethod as HttpMethod;
        use crate::response::headers::ElifHeaderMap;
        let mut container = IocContainer::new();
        container.bind::<TestLoggerService, TestLoggerService>();
        container.build().expect("Container build failed");
        let registry = Arc::new(
            MiddlewareRegistryBuilder::new()
                .container(Arc::new(container))
                .register::<TestIocMiddleware>("lazy_test")
                .build()
                .expect("Failed to build middleware registry"),
        );
        let lazy_middleware = LazyIocMiddleware::new("lazy_test".to_string(), registry);
        let request = ElifRequest::new(
            HttpMethod::GET,
            "/test".parse().unwrap(),
            ElifHeaderMap::new(),
        );
        let next = Next::new(|_req| Box::pin(async { ElifResponse::ok().text("Success") }));
        let response = lazy_middleware.handle(request, next).await;
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::OK
        );
    }

    #[tokio::test]
    async fn test_lazy_ioc_middleware_instantiation_failure() {
        use crate::request::method::ElifMethod as HttpMethod;
        use crate::response::headers::ElifHeaderMap;
        // Deliberately no `TestLoggerService` binding and no `build()`, so
        // `TestIocMiddleware::from_ioc_container` cannot resolve its dependency.
        let container = IocContainer::new();
        let registry = Arc::new(
            MiddlewareRegistryBuilder::new()
                .container(Arc::new(container))
                .register::<TestIocMiddleware>("failing_middleware")
                .build()
                .expect("Failed to build middleware registry"),
        );
        let lazy_middleware = LazyIocMiddleware::new("failing_middleware".to_string(), registry);
        let request = ElifRequest::new(
            HttpMethod::GET,
            "/test".parse().unwrap(),
            ElifHeaderMap::new(),
        );
        let next = Next::new(|_req| {
            Box::pin(async {
                panic!("Next middleware should not be called when middleware instantiation fails!");
            })
        });
        let response = lazy_middleware.handle(request, next).await;
        assert_eq!(
            response.status_code(),
            crate::response::status::ElifStatusCode::INTERNAL_SERVER_ERROR
        );
    }

    #[tokio::test]
    async fn test_middleware_context_from_request() {
        use crate::request::method::ElifMethod as HttpMethod;
        use crate::response::headers::{ElifHeaderMap, ElifHeaderName, ElifHeaderValue};
        let mut headers = ElifHeaderMap::new();
        headers.insert(
            ElifHeaderName::from_str("x-request-id").unwrap(),
            ElifHeaderValue::from_str("req-123").unwrap(),
        );
        headers.insert(
            ElifHeaderName::from_str("x-user-id").unwrap(),
            ElifHeaderValue::from_str("user-456").unwrap(),
        );
        let request = ElifRequest::new(HttpMethod::POST, "/api/test".parse().unwrap(), headers);
        let context = MiddlewareContext::from_request(&request);
        assert_eq!(context.request_id, "req-123");
        assert_eq!(context.user_id, Some("user-456".to_string()));
        assert!(context.session_id.is_none());
    }

    #[tokio::test]
    async fn test_middleware_group() {
        let mut container = IocContainer::new();
        container.bind::<TestLoggerService, TestLoggerService>();
        container.build().expect("Container build failed");
        let registry = Arc::new(
            MiddlewareRegistryBuilder::new()
                .container(Arc::new(container))
                .register::<TestIocMiddleware>("group1")
                .register::<TestIocMiddleware>("group2")
                .build()
                .expect("Failed to build middleware registry"),
        );
        let group = MiddlewareGroup::new(
            "test_group".to_string(),
            vec!["group1".to_string(), "group2".to_string()],
            registry,
        );
        assert_eq!(group.name(), "test_group");
        assert_eq!(group.middleware_names().len(), 2);
        let middleware = group
            .create_middleware(None)
            .expect("Failed to create group middleware");
        assert_eq!(middleware.len(), 2);
    }
}