use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use parking_lot::Mutex;
pub use axum::{
body::{Body, BoxBody, Bytes, HttpBody},
BoxError,
handler::Handler,
http::Request,
http::StatusCode,
http::uri::Scheme,
Json,
response::Response,
extract::Extension
};
pub use hyper::{
upgrade,
client::conn
};
#[cfg(feature = "tls")]
pub use axum_server::{
service::*
};
pub use product_os_router::{ Layer, Service, ServiceExt, service_fn };
pub use product_os_router::Router;
pub use product_os_router::Method;
pub use product_os_router::MethodRouter;
pub use product_os_router::ProductOSRouter;
pub use product_os_router::ServiceBuilder;
#[cfg(feature = "middleware")]
pub use product_os_router::*;
#[cfg(feature = "ws")]
pub use axum::{
extract::*,
};
use axum::response::{IntoResponse, Redirect};
#[cfg(feature = "sse")]
pub use axum::response::sse::{Event, Sse};
#[cfg(feature = "compression")]
use tower_http::{
compression::CompressionLayer,
decompression::DecompressionLayer
};
#[cfg(feature = "csrf")]
use axum_csrf::{CsrfConfig, CsrfLayer};
use product_os_configuration::Configuration;
use product_os_security::certificates::Certificates;
use tokio::runtime::Handle;
use tracing::Level;
#[cfg(feature = "cors")]
use tower_http::cors::CorsLayer;
pub struct ProductOSServerSimple<S = ()> {
router: ProductOSRouter<S>,
config: product_os_configuration::Configuration,
certificates: Certificates
}
impl<S> ProductOSServerSimple<S>
where
S: Clone + Send + Sync + 'static {
#[tokio::main]
pub async fn new_with_config_sync(config: product_os_configuration::Configuration, mut router: ProductOSRouter<S>) -> Self {
ProductOSServerSimple::new_with_config(config, router).await
}
pub async fn new_with_config(config: product_os_configuration::Configuration, mut router: ProductOSRouter<S>) -> Self {
crate::logging::set_global_logger(crate::logging::define_logging(config.log_level()));
tracing::info!("Log Level: {}", config.log_level());
let certificates = match config.certificate.clone() {
None => {
tracing::info!("Generating self-signed certificate");
product_os_security::certificates::Certificates::new(Some(vec!(config.get_host())))
},
Some(cert_config) => {
tracing::info!("Using configuration file for certificate");
product_os_security::certificates::Certificates::new_from_file(cert_config.cert_file, cert_config.key_file)
}
};
match &config.security {
None => {}
Some(security) => {
if security.enable {
router.add_default_header("content-security-policy".to_string(), crate::csp::ContentSecurityPolicy::new(&config).get_csp());
router.add_default_header("cross-origin-embedder-policy".to_string(), "require-corp".to_string());
router.add_default_header("cross-origin-opener-policy".to_string(), "same-origin".to_string());
router.add_default_header("referrer-policy".to_string(), "strict-origin-when-cross-origin".to_string());
router.add_default_header("strict-transport-security".to_string(), "max-age=86400; includeSubDomains; preload".to_string());
router.add_default_header("x-Content-type-options".to_string(), "nosniff".to_string());
router.add_default_header("x-powered-by".to_string(), config.network.host.to_owned());
#[cfg(feature = "csrf")]
{
if security.csrf {
router.add_middleware(CsrfLayer::new(CsrfConfig::default()));
tracing::info!("CSRF added as extension middleware");
}
}
}
}
}
#[cfg(feature = "compression")]
{
match &config.compression {
None => {}
Some(compress) => {
if compress.enable {
let mut compression = CompressionLayer::new();
if config.is_compression_gzip() { compression = compression.gzip(true); }
if config.is_compression_deflate() { compression = compression.deflate(true); }
if config.is_compression_brotli() { compression = compression.br(true); }
router.add_middleware(compression);
let mut decompression = DecompressionLayer::new();
if config.is_compression_gzip() { decompression = decompression.gzip(true); }
if config.is_compression_deflate() { decompression = decompression.deflate(true); }
if config.is_compression_brotli() { decompression = decompression.br(true); }
router.add_middleware(decompression);
}
}
}
}
Self {
router,
config,
certificates
}
}
pub fn new(router: ProductOSRouter<S>) -> Self {
let config = Configuration::new();
let log_level = config.log_level();
crate::logging::set_global_logger(crate::logging::define_logging(log_level));
tracing::info!("Log Level: {}", config.log_level());
let certificates = {
tracing::info!("Generating self-signed certificate on localhost");
Certificates::new(Some(vec!(config.get_host())))
};
Self {
router,
config,
certificates
}
}
pub fn set_logging(&mut self, log_level: Level) {
crate::logging::set_global_logger(crate::logging::define_logging(log_level));
tracing::info!("Log Level: {}", self.config.log_level());
}
pub fn set_certificate(&mut self, certificate: Option<product_os_configuration::Certificate>) {
self.certificates = match certificate.clone() {
None => {
tracing::info!("Generating self-signed certificate");
Certificates::new(Some(vec!(self.config.get_host())))
},
Some(cert_config) => {
tracing::info!("Using configuration file for certificate");
Certificates::new_from_file(cert_config.cert_file, cert_config.key_file)
}
};
}
pub fn set_security(&mut self, security: Option<product_os_configuration::Security>) {
match security {
None => {}
Some(security) => {
if security.enable {
self.router.add_default_header("content-security-policy".to_string(), crate::csp::ContentSecurityPolicy::new(&self.config).get_csp());
self.router.add_default_header("cross-origin-embedder-policy".to_string(), "require-corp".to_string());
self.router.add_default_header("cross-origin-opener-policy".to_string(), "same-origin".to_string());
self.router.add_default_header("referrer-policy".to_string(), "strict-origin-when-cross-origin".to_string());
self.router.add_default_header("strict-transport-security".to_string(), "max-age=86400; includeSubDomains; preload".to_string());
self.router.add_default_header("x-Content-type-options".to_string(), "nosniff".to_string());
self.router.add_default_header("x-powered-by".to_string(), self.config.network.host.to_owned());
#[cfg(feature = "csrf")]
{
if security.csrf {
self.router.add_middleware(CsrfLayer::new(CsrfConfig::default()));
tracing::info!("CSRF added as extension middleware");
}
}
}
}
}
}
#[cfg(feature = "compression")]
pub fn set_compression(&mut self, compression_config: Option<product_os_configuration::Compression>) {
match compression_config {
None => {}
Some(compression_config) => {
if compression_config.enable {
let mut compression = CompressionLayer::new();
if compression_config.gzip { compression = compression.gzip(true); }
if compression_config.deflate { compression = compression.deflate(true); }
if compression_config.brotli { compression = compression.br(true); }
self.router.add_middleware(compression);
let mut decompression = DecompressionLayer::new();
if compression_config.gzip { decompression = decompression.gzip(true); }
if compression_config.deflate { decompression = decompression.deflate(true); }
if compression_config.brotli { decompression = decompression.br(true); }
self.router.add_middleware(decompression);
}
}
}
}
pub fn get_router(&mut self) -> &mut ProductOSRouter<S> {
&mut self.router
}
pub fn add_route(&mut self, path: &str, method_router: MethodRouter<S>) {
self.router.add_route(path, method_router);
}
pub fn set_fallback(&mut self, method_router: MethodRouter<S>) {
self.router.set_fallback(method_router);
}
pub fn add_get<H, T>(&mut self, path: &str, handler: H)
where
H: Handler<T, S, Body>,
T: 'static
{
self.router.add_get(path, handler);
}
pub fn add_post<H, T>(&mut self, path: &str, handler: H)
where
H: Handler<T, S, Body>,
T: 'static
{
self.router.add_post(path, handler);
}
pub fn add_handler<H, T>(&mut self, path: &str, method: Method, handler: H)
where
H: Handler<T, S, Body>,
T: 'static
{
self.router.add_handler(path, method, handler);
}
pub fn set_fallback_handler<H, T>(&mut self, handler: H)
where
H: Handler<T, S, Body>,
T: 'static
{
self.router.set_fallback_handler(handler);
}
#[cfg(feature = "cors")]
pub fn add_cors_handler<H, T>(&mut self, path: &str, method: Method, handler: H)
where
H: Handler<T, S, Body>,
T: 'static
{
self.router.add_cors_handler(path, method, handler);
}
#[cfg(feature = "ws")]
pub fn add_ws_handler<H, T>(&mut self, path: &str, ws_handler: H)
where
H: Handler<T, S, Body>,
T: 'static
{
self.add_get(path, ws_handler);
}
#[cfg(feature = "sse")]
pub fn add_sse_handler<H, T>(&mut self, path: &str, sse_handler: H)
where
H: Handler<T, S, Body>,
T: 'static
{
self.add_get(path, sse_handler);
}
pub fn add_handlers<H, T>(&mut self, path: &str, handlers: HashMap<Method, H>)
where
H: Handler<T, S, Body>,
T: 'static
{
self.router.add_handlers(path, handlers);
}
#[cfg(feature = "cors")]
pub fn add_cors_handlers<H, T>(&mut self, path: &str, handlers: HashMap<Method, H>)
where
H: Handler<T, S, Body>,
T: 'static
{
self.router.add_cors_handlers(path, handlers);
}
pub fn add_middleware<L, NewResBody>(&mut self, middleware: L)
where
L: Layer<product_os_router::Route<Body>> + Clone + Send + 'static,
L::Service: Service<Request<Body>, Response = Response<NewResBody>, Error = Infallible> + Clone + Send + 'static,
<L::Service as Service<Request<Body>>>::Future: Send + 'static,
NewResBody: HttpBody<Data = Bytes> + Send + 'static,
NewResBody::Error: Into<BoxError>,
{
self.router.add_middleware(middleware);
}
pub fn set_router(&mut self, router: ProductOSRouter<S>) {
self.router = router;
}
pub async fn create_dual_service_server(&mut self, serve_on_main_thread: bool, custom_port: Option<u16>, custom_router: Option<Router>, with_connect_info: bool, force_secure: bool) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
#[cfg(feature = "tls")]
{
let router = match custom_router {
None => self.router.get_router(),
Some(r) => r
};
let address = self.config.socket_address(custom_port);
match crate::dual_server::create_dual_service(address, Some(self.certificates.to_owned())) {
Ok(server) => {
tracing::info!("HTTPS and HTTP server listening on {}", address);
if serve_on_main_thread {
if with_connect_info {
server
.set_upgrade(force_secure)
.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.await.unwrap();
}
else {
server
.set_upgrade(force_secure)
.serve(router.into_make_service())
.await.unwrap();
}
}
else {
if with_connect_info {
tokio::spawn(async move {
server
.set_upgrade(force_secure)
.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.await.unwrap();
});
}
else {
tokio::spawn(async move {
server
.set_upgrade(force_secure)
.serve(router.into_make_service())
.await.unwrap();
});
}
}
},
Err(e) => tracing::error!("Error starting HTTPS server: {}", e)
}
}
#[cfg(not(feature = "tls"))]
{
tracing::info!("TLS feature is not enabled - please include the \"tls\" feature in your toml config file");
}
Ok(())
}
pub async fn create_https_server(&mut self, serve_on_main_thread: bool, custom_port: Option<u16>, custom_router: Option<Router>, with_connect_info: bool) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
#[cfg(feature = "tls")]
{
let router = match custom_router {
None => self.router.get_router(),
Some(r) => r
};
let address = self.config.socket_address(custom_port);
match crate::https_server::create_https_service(address, Some(self.certificates.to_owned())) {
Ok(server) => {
tracing::info!("HTTPS server listening on {}", address);
if serve_on_main_thread {
if with_connect_info {
server.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.await.unwrap();
}
else {
server.serve(router.into_make_service())
.await.unwrap();
}
}
else {
if with_connect_info {
tokio::spawn(async move {
server.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.await.unwrap();
});
}
else {
tokio::spawn(async move {
server.serve(router.into_make_service())
.await.unwrap();
});
}
}
},
Err(e) => tracing::error!("Error starting HTTPS server: {}", e)
}
}
#[cfg(not(feature = "tls"))]
{
tracing::info!("TLS feature is not enabled - please include the \"tls\" feature in your toml config file");
}
Ok(())
}
pub async fn create_http_server(&mut self, serve_on_main_thread: bool, custom_port: Option<u16>, custom_router: Option<Router>, with_connect_info: bool) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let router: Router = match custom_router {
None => self.router.get_router(),
Some(r) => r
};
let address = self.config.socket_address(custom_port);
match crate::http_server::create_http_service(address) {
Ok(server) => {
tracing::info!("HTTP server listening on {}", address);
if serve_on_main_thread {
if with_connect_info {
server.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.await.unwrap();
}
else {
server.serve(router.into_make_service())
.await.unwrap();
}
}
else {
if with_connect_info {
tokio::spawn(async move {
server.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.await.unwrap();
});
}
else {
tokio::spawn(async move {
server.serve(router.into_make_service())
.await.unwrap();
});
}
}
},
Err(e) => tracing::error!("Error starting HTTP server: {}", e)
}
Ok(())
}
#[cfg(feature = "custom")]
pub async fn create_dual_server_custom(&mut self, serve_on_main_thread: bool, custom_port: Option<u16>, force_secure: bool, service_function: impl FnMut<(Request<Body>,), Output = impl Future<Output = Result<Response, Infallible>> + Send + 'static> + Send + Clone + 'static) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
#[cfg(feature = "tls")]
{
let address = self.config.socket_address(custom_port);
match dual_server::create_dual_service(address, Some(self.certificates.to_owned())) {
Ok(server) => {
tracing::info!("HTTPS custom server listening on {}", address);
let service = product_os_router::service_fn(service_function);
let make_service = product_os_router::Shared::new(service);
if serve_on_main_thread {
if with_connect_info {
server
.set_upgrade(force_secure)
.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.await.unwrap();
}
else {
server
.set_upgrade(force_secure)
.serve(router.into_make_service())
.await.unwrap();
}
}
else {
if with_connect_info {
tokio::spawn(async move {
server
.set_upgrade(force_secure)
.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.await.unwrap();
});
}
else {
tokio::spawn(async move {
server
.set_upgrade(force_secure)
.serve(router.into_make_service())
.await.unwrap();
});
}
}
},
Err(e) => tracing::error!("Error starting custom HTTPS server: {}", e)
}
}
#[cfg(not(feature = "tls"))]
{
tracing::info!("TLS feature is not enabled - please include the \"tls\" feature in your toml config file");
}
Ok(())
}
#[cfg(feature = "custom")]
pub async fn create_https_server_custom(&mut self, serve_on_main_thread: bool, custom_port: Option<u16>, service_function: impl FnMut<(Request<Body>,), Output = impl Future<Output = Result<Response, Infallible>> + Send + 'static> + Send + Clone + 'static) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
#[cfg(feature = "tls")]
{
let address = self.config.socket_address(custom_port);
match crate::https_server::create_https_service(address, Some(self.certificates.to_owned())) {
Ok(server) => {
tracing::info!("HTTPS custom server listening on {}", address);
let service = product_os_router::service_fn(service_function);
let make_service = product_os_router::Shared::new(service);
if serve_on_main_thread {
if with_connect_info {
server.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.await.unwrap();
}
else {
server.serve(router.into_make_service())
.await.unwrap();
}
}
else {
if with_connect_info {
tokio::spawn(async move {
server.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.await.unwrap();
});
}
else {
tokio::spawn(async move {
server.serve(router.into_make_service())
.await.unwrap();
});
}
}
},
Err(e) => tracing::error!("Error starting custom HTTPS server: {}", e)
}
}
#[cfg(not(feature = "tls"))]
{
tracing::info!("TLS feature is not enabled - please include the \"tls\" feature in your toml config file");
}
Ok(())
}
#[cfg(feature = "custom")]
pub async fn create_http_server_custom(&mut self, serve_on_main_thread: bool, custom_port: Option<u16>, service_function: impl FnMut<(Request<Body>,), Output = impl Future<Output = Result<Response, Infallible>> + Send + 'static> + Send + Clone + 'static) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
{
let address = self.config.socket_address(custom_port);
match crate::http_server::create_http_service(address) {
Ok(server) => {
tracing::info!("HTTP custom server listening on {}", address);
let service = product_os_router::service_fn(service_function);
let make_service = Shared::new(service);
if serve_on_main_thread {
if with_connect_info {
server.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.await.unwrap();
}
else {
server.serve(router.into_make_service())
.await.unwrap();
}
}
else {
if with_connect_info {
tokio::spawn(async move {
server.serve(router.into_make_service_with_connect_info::<SocketAddr>())
.await.unwrap();
});
}
else {
tokio::spawn(async move {
server.serve(router.into_make_service())
.await.unwrap();
});
}
}
},
Err(e) => tracing::error!("Error starting HTTP server: {}", e)
}
Ok(())
}
pub async fn init_stores(&mut self) {
}
pub fn get_config(&self) -> Configuration {
self.config.clone()
}
pub fn update_config(&mut self, config: Configuration) {
self.config = config;
}
pub fn get_base_url(&self) -> url::Url {
self.config.url_address().to_owned()
}
#[cfg(feature = "custom")]
#[tokio::main]
pub async fn start_custom_sync(&mut self, service_function: impl FnMut<(Request<Body>,), Output = impl Future<Output = Result<Response, Infallible>> + Send + 'static> + Send + Clone + 'static) -> Result<Handle, Box<dyn std::error::Error + Send + Sync>>
{
self.start_custom(service_function).await
}
#[cfg(feature = "custom")]
pub async fn start_custom(&mut self, service_function: impl FnMut<(Request<Body>,), Output = impl Future<Output = Result<Response, Infallible>> + Send + 'static> + Send + Clone + 'static) -> Result<Handle, Box<dyn std::error::Error + Send + Sync>>
{
if self.config.is_secure() {
if self.config.network.allow_insecure {
if self.config.network.insecure_force_secure {
self.create_http_server(false, Some(self.config.network.insecure_port), service, false).await?;
}
else {
self.create_http_server(false, Some(self.config.network.insecure_port), service, false).await?;
}
}
self.create_https_server(true, None, None, false).await?;
}
else {
self.create_http_server(true, None, None, false).await?;
};
let handle = Handle::current();
Ok(handle)
}
#[tokio::main]
pub async fn start_sync(&mut self) -> Result<Handle, Box<dyn std::error::Error + Send + Sync>> {
self.start().await
}
pub async fn start(&mut self) -> Result<Handle, Box<dyn std::error::Error + Send + Sync>> {
if self.config.is_secure() {
if self.config.network.allow_insecure {
if !self.config.network.insecure_use_different_port {
self.create_dual_service_server(false, Some(self.config.insecure_port()), None, false, self.config.network.insecure_force_secure).await?;
}
else {
if self.config.network.insecure_force_secure {
let mut router = Router::new();
router = router.fallback(MethodRouter::new()
.get(force_secure_handler)
.post(force_secure_handler)
.put(force_secure_handler)
.patch(force_secure_handler)
.delete(force_secure_handler)
.trace(force_secure_handler)
.head(force_secure_handler)
.options(force_secure_handler));
self.create_http_server(false, Some(self.config.network.insecure_port), Some(router), false).await?;
}
else {
self.create_http_server(false, Some(self.config.network.insecure_port), None, false).await?;
}
}
}
self.create_https_server(true, None, None, false).await?;
}
else {
self.create_http_server(true, None, None, false).await?;
};
let handle = Handle::current();
Ok(handle)
}
}
async fn force_secure_handler(request: Request<Body>) -> Response {
let uri_path = request.uri().path();
let mut url: String = String::new();
url.push_str(uri_path);
Redirect::permanent(url.as_str()).into_response()
}