use actix_web::body::MessageBody;
use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready};
use actix_web::{Error, body::BoxBody};
use futures::future::{LocalBoxFuture, Ready, ready};
use std::sync::Arc;
use tracing::debug;
use super::engine::GuardrailEngine;
use super::types::GuardrailAction;
/// Actix-web middleware factory that wraps downstream services with a
/// [`GuardrailMiddlewareService`] sharing one [`GuardrailEngine`].
pub struct GuardrailMiddleware {
    /// Engine shared (via `Arc`) with every service this factory produces.
    engine: Arc<GuardrailEngine>,
}
impl GuardrailMiddleware {
pub fn new(engine: Arc<GuardrailEngine>) -> Self {
Self { engine }
}
}
impl<S, B> Transform<S, ServiceRequest> for GuardrailMiddleware
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
S::Future: 'static,
B: MessageBody + 'static,
{
type Response = ServiceResponse<BoxBody>;
type Error = Error;
type InitError = ();
type Transform = GuardrailMiddlewareService<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(GuardrailMiddlewareService {
service,
engine: self.engine.clone(),
}))
}
}
/// Per-service middleware instance produced by [`GuardrailMiddleware`].
pub struct GuardrailMiddlewareService<S> {
    /// The next service in the actix-web chain.
    service: S,
    /// Engine used to decide whether a request path is excluded.
    engine: Arc<GuardrailEngine>,
}
/// Request handling for [`GuardrailMiddlewareService`].
///
/// Readiness is delegated to the inner service (`forward_ready!`). Every
/// request is forwarded unchanged, and the response body is boxed into
/// [`BoxBody`].
impl<S, B> Service<ServiceRequest> for GuardrailMiddlewareService<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
    B: MessageBody + 'static,
{
    type Response = ServiceResponse<BoxBody>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    /// Forwards `req` to the inner service and boxes the response body.
    ///
    /// Paths the engine reports as excluded are forwarded silently. All other
    /// paths emit a `debug!` trace when the returned future is first polled.
    /// Errors from the inner service are propagated unchanged.
    ///
    /// NOTE(review): no guardrail check is applied to the request or response
    /// here yet, so this is currently a pass-through apart from the logging.
    fn call(&self, req: ServiceRequest) -> Self::Future {
        let path = req.path().to_string();
        // The engine is only consulted synchronously, so borrow it through
        // `self` instead of cloning the `Arc` on every request.
        let excluded = self.engine.is_path_excluded(&path);
        let fut = self.service.call(req);
        Box::pin(async move {
            if !excluded {
                debug!("Guardrails middleware processing request to {}", path);
            }
            let res = fut.await?;
            Ok(res.map_into_boxed_body())
        })
    }
}
/// Summary of a guardrail check outcome.
///
/// NOTE(review): nothing in this module populates it outside the tests;
/// field meanings below are inferred from names and should be confirmed
/// against the producer.
#[derive(Clone, Debug)]
pub struct GuardrailCheckContext {
    /// Presumably `true` when the content passed all guardrails.
    pub passed: bool,
    /// Action selected for the checked content.
    pub action: GuardrailAction,
    /// Presumably the number of guardrail violations detected.
    pub violation_count: usize,
    /// Rewritten content, if a guardrail produced one; `None` otherwise.
    pub modified_content: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::guardrails::config::GuardrailConfig;

    /// Building the middleware from a default-configured engine must not panic.
    #[test]
    fn test_middleware_creation() {
        let shared_engine = Arc::new(GuardrailEngine::new(GuardrailConfig::default()).unwrap());
        let _middleware = GuardrailMiddleware::new(shared_engine);
    }

    /// A context built by hand exposes the field values it was given.
    #[test]
    fn test_guardrail_check_context() {
        let ctx = GuardrailCheckContext {
            modified_content: None,
            violation_count: 0,
            action: GuardrailAction::Allow,
            passed: true,
        };
        assert!(ctx.passed);
        assert_eq!(ctx.violation_count, 0);
    }
}