use std::sync::Arc;
use arc_swap::ArcSwap;
use super::RewriteConfig;
use super::pipeline::driver::Pipeline;
use super::pipeline::middleware::{RequestMiddleware, ResponseMiddleware};
use super::pipeline::middlewares::{
ClientInfoInjectMiddleware, CspRewriteMiddleware, EnvelopeSealMiddleware,
HealthTrackMiddleware, SchemaIngestMiddleware, SchemaStaleMiddleware, SessionDeleteMiddleware,
SessionRecordMiddleware, SessionTouchMiddleware, TargetExtractMiddleware, UrlMapMiddleware,
};
use super::router::ProxyRouter;
use super::transport::ProxyTransport;
/// The proxy's concrete pipeline type: the generic [`Pipeline`] driver
/// instantiated with [`ProxyRouter`] and [`ProxyTransport`].
///
/// This is the type returned by [`build_default_pipeline`].
pub type ProxyPipeline = Pipeline<ProxyRouter, ProxyTransport>;
/// Assembles the proxy's default middleware pipeline.
///
/// Middlewares are registered in the order they appear below; presumably the
/// [`Pipeline`] driver executes them in that order (confirm against
/// `Pipeline`). The request chain holds session delete, session touch,
/// client-info inject and target extract. The response chain holds schema
/// ingest, schema stale, CSP rewrite, session record, health track, URL map
/// and envelope seal.
///
/// `rewrite_config` is shared: the CSP-rewrite middleware receives a cloned
/// `Arc` handle and the URL-map middleware takes ownership of the original.
///
/// Every registered middleware is logged at `info` level with a `chain` field
/// (`"request"` or `"response"`) and its `name()`.
pub fn build_default_pipeline(rewrite_config: Arc<ArcSwap<RewriteConfig>>) -> ProxyPipeline {
    // Inbound side, in registration order.
    let request_chain: Vec<Box<dyn RequestMiddleware>> = vec![
        Box::new(SessionDeleteMiddleware),
        Box::new(SessionTouchMiddleware),
        Box::new(ClientInfoInjectMiddleware),
        Box::new(TargetExtractMiddleware),
    ];

    // Outbound side, in registration order. Two middlewares need the rewrite
    // config: the first gets a refcount bump, the last user takes the handle.
    let response_chain: Vec<Box<dyn ResponseMiddleware>> = vec![
        Box::new(SchemaIngestMiddleware),
        Box::new(SchemaStaleMiddleware),
        Box::new(CspRewriteMiddleware::new(Arc::clone(&rewrite_config))),
        Box::new(SessionRecordMiddleware),
        Box::new(HealthTrackMiddleware),
        Box::new(UrlMapMiddleware::new(rewrite_config)),
        Box::new(EnvelopeSealMiddleware),
    ];

    // Announce each registered middleware, request side first.
    request_chain.iter().for_each(|middleware| {
        tracing::info!(chain = "request", name = middleware.name(), "middleware registered");
    });
    response_chain.iter().for_each(|middleware| {
        tracing::info!(chain = "response", name = middleware.name(), "middleware registered");
    });

    Pipeline::new(request_chain, response_chain, ProxyRouter, ProxyTransport)
}
#[cfg(test)]
// Test names follow a `subject__expectation` double-underscore convention,
// which the non_snake_case lint would otherwise reject.
#[allow(non_snake_case)]
mod tests {
    use super::*;
    use crate::proxy::pipeline::middlewares::test_support::test_proxy_state;

    /// Middleware names the request chain must report, in registration order.
    const EXPECTED_REQUEST_CHAIN: [&str; 4] = [
        "session_delete",
        "session_touch",
        "client_info_inject",
        "target_extract",
    ];

    /// Middleware names the response chain must report, in registration order.
    const EXPECTED_RESPONSE_CHAIN: [&str; 7] = [
        "schema_ingest",
        "schema_stale",
        "csp_rewrite",
        "session_record",
        "health_track",
        "url_map",
        "envelope_seal",
    ];

    #[tokio::test]
    async fn build_default_pipeline__registers_expected_chain_names_in_order() {
        let state = test_proxy_state();
        let pipeline = build_default_pipeline(state.rewrite_config.clone());

        assert_eq!(
            pipeline.request_chain_names(),
            EXPECTED_REQUEST_CHAIN.to_vec(),
        );
        assert_eq!(
            pipeline.response_chain_names(),
            EXPECTED_RESPONSE_CHAIN.to_vec(),
        );
    }
}