use super::{bind::bind_http, config::HttpConfig};
use crate::actix::http::{BuildCors, CorsConfig};
use crate::app::{Startup, StartupExt};
use crate::{
app::RuntimeConfig,
core::tls::{TlsAuthConfig, WithTlsAuthConfig},
};
use actix_cors::Cors;
use actix_http::Extensions;
use actix_web::{
middleware,
web::{self, ServiceConfig},
App, HttpServer,
};
use actix_web_extras::middleware::Condition;
use futures_core::future::BoxFuture;
use futures_util::{FutureExt, TryFutureExt};
use std::any::Any;
/// Connection hook type, handed to `HttpServer::on_connect`.
///
/// The callback receives the underlying connection as `&dyn Any` together with the
/// connection's `Extensions` container.
pub type OnConnectFn = dyn Fn(&dyn Any, &mut Extensions) + Sync + Send + 'static;
/// Builder collecting everything required to start the actix-web HTTP server.
///
/// `F` is the application factory that registers services on each `App` the server creates.
pub struct HttpBuilder<F: Fn(&mut ServiceConfig) + Send + Clone + 'static> {
    /// HTTP server settings (bind address, TLS files, payload limits, workers, ...).
    config: HttpConfig,
    /// CORS settings used as a fallback when `config` does not carry its own.
    default_cors: Option<CorsConfig>,
    /// Application factory configuring the services of each app instance.
    app_builder: Box<F>,
    /// Optional connection hook, forwarded to `HttpServer::on_connect`.
    on_connect: Option<Box<OnConnectFn>>,
    /// TLS authentication settings, handed to the bind step.
    tls_auth_config: TlsAuthConfig,
    /// When `true`, request logging uses `tracing_actix_web::TracingLogger` instead of the
    /// plain actix `Logger` middleware.
    tracing: bool,
}
impl<F> HttpBuilder<F>
where
    F: Fn(&mut ServiceConfig) + Send + Clone + 'static,
{
    /// Create a new builder.
    ///
    /// * `config` – HTTP server settings.
    /// * `runtime` – optional runtime configuration. If present, its tracing setting selects
    ///   `tracing_actix_web::TracingLogger` for request logging. If absent, tracing is disabled
    ///   and the plain actix `Logger` is used.
    /// * `app_builder` – application factory. It is invoked from the `HttpServer` app factory
    ///   to register services on each `App` instance.
    pub fn new(config: HttpConfig, runtime: Option<&RuntimeConfig>, app_builder: F) -> Self {
        Self {
            config,
            default_cors: None,
            app_builder: Box::new(app_builder),
            on_connect: None,
            tls_auth_config: TlsAuthConfig::default(),
            tracing: runtime.map(|r| r.tracing.is_enabled()).unwrap_or_default(),
        }
    }

    /// Set the CORS configuration to use when the `HttpConfig` does not provide one.
    ///
    /// Passing `None` clears a previously set default.
    pub fn default_cors<C: Into<Option<CorsConfig>>>(mut self, default_cors: C) -> Self {
        self.default_cors = default_cors.into();
        self
    }

    /// Register a connection hook, forwarded to `HttpServer::on_connect` when the server is built.
    ///
    /// Calling this again replaces the previously registered hook.
    pub fn on_connect<O>(mut self, on_connect: O) -> Self
    where
        O: Fn(&dyn Any, &mut Extensions) + Send + Sync + 'static,
    {
        self.on_connect = Some(Box::new(on_connect));
        self
    }

    /// Set the TLS authentication configuration handed to the bind step.
    pub fn tls_auth_config<I: Into<TlsAuthConfig>>(mut self, tls_auth_config: I) -> Self {
        self.tls_auth_config = tls_auth_config.into();
        self
    }

    /// Build the server (see [`Self::run`]) and spawn the resulting future on `startup`.
    ///
    /// # Errors
    ///
    /// Returns any error produced while building or binding the server.
    pub fn start(self, startup: &mut dyn Startup) -> anyhow::Result<()> {
        startup.spawn(self.run()?);
        Ok(())
    }

    /// The effective CORS configuration.
    ///
    /// The `HttpConfig`'s own CORS settings take precedence over the builder's default.
    fn cors_config(&self) -> Option<CorsConfig> {
        self.config
            .cors
            .as_ref()
            .or(self.default_cors.as_ref())
            .cloned()
    }

    /// Build and bind the HTTP server and return the future that runs it.
    ///
    /// # Errors
    ///
    /// Fails if the prometheus middleware cannot be built, if the effective CORS configuration
    /// is invalid, or if `bind_http` fails.
    pub fn run(
        // `mut` is only required with the `openssl` feature (TLS-PSK removal below).
        #[allow(unused_mut)] mut self,
    ) -> Result<BoxFuture<'static, Result<(), anyhow::Error>>, anyhow::Error> {
        let max_payload_size = self.config.max_payload_size;
        let max_json_payload_size = self.config.max_json_payload_size;
        let use_tracing = self.tracing;

        let prometheus = actix_web_prom::PrometheusMetricsBuilder::new(
            self.config.metrics_namespace.as_deref().unwrap_or("drogue"),
        )
        .registry(prometheus::default_registry().clone())
        .build()
        .map_err(|err| anyhow::anyhow!("Failed to build prometheus middleware: {err}"))?;

        let cors = self.cors_config();
        log::debug!("Effective CORS config {cors:?}");
        // Validate once up front, so the `expect` in the app factory below cannot fire.
        let _: Option<Cors> = cors.build_cors()?;

        // The logger selection depends only on configuration. Report it once here instead of
        // from inside the app factory, which actix invokes once per worker.
        log::debug!(
            "Loggers ({}) - logger: {}, tracing: {}",
            use_tracing,
            !use_tracing,
            use_tracing
        );

        let app_builder = self.app_builder;
        let mut main = HttpServer::new(move || {
            let cors: Option<Cors> = cors.build_cors().expect("Configuration must be valid");
            // Exactly one of the two request loggers is active.
            let (logger, tracing_logger) = if use_tracing {
                (None, Some(tracing_actix_web::TracingLogger::default()))
            } else {
                (Some(middleware::Logger::default()), None)
            };

            App::new()
                .wrap(Condition::from_option(cors))
                .wrap(prometheus.clone())
                .wrap(Condition::from_option(logger))
                .wrap(Condition::from_option(tracing_logger))
                .app_data(web::PayloadConfig::new(max_payload_size))
                .app_data(web::JsonConfig::default().limit(max_json_payload_size))
                .configure(|cfg| app_builder(cfg))
        });

        if let Some(on_connect) = self.on_connect {
            main = main.on_connect(on_connect);
        }

        // NOTE(review): without the `openssl` feature this flag has no effect here.
        if self.config.disable_tls_psk {
            #[cfg(feature = "openssl")]
            self.tls_auth_config.psk.take();
        }

        let mut main = bind_http(
            main,
            self.config.bind_addr,
            self.config
                .disable_tls
                .with_tls_auth_config(self.tls_auth_config),
            self.config.key_file,
            self.config.cert_bundle_file,
        )?;

        if let Some(workers) = self.config.workers {
            main = main.workers(workers);
        }

        Ok(main.run().err_into().boxed())
    }
}