use std::{
collections::{BTreeMap, HashMap},
net::SocketAddr,
sync::{Arc, Mutex, OnceLock},
};
use crate::config::{matches_convert::convert_config_to_kernel, plugin_filter_dto::global_batch_mount_plugin, PluginConfig, SgProtocolConfig, SgTlsMode};
use hyper::Version;
use spacegate_config::{BackendHost, Config, ConfigItem};
use spacegate_kernel::{
helper_layers::reload::Reloader,
listener::SgListen,
service::http_gateway::{builder::default_gateway_route_fallback, create_http_router, HttpRouterService},
ArcHyperService, BoxError,
};
use spacegate_plugin::{mount::MountPointIndex, PluginRepository};
use std::time::Duration;
use std::vec::Vec;
use tokio::time::timeout;
use tracing::{debug, error, info, instrument, warn};
use tokio_rustls::rustls::{self, pki_types::PrivateKeyDer};
use tokio_util::sync::CancellationToken;
fn collect_http_route(
gateway_name: Arc<str>,
http_routes: impl IntoIterator<Item = (String, crate::SgHttpRoute)>,
) -> Result<HashMap<String, spacegate_kernel::service::http_route::HttpRoute>, BoxError> {
http_routes
.into_iter()
.map(|(name, route)| {
let route_name: Arc<str> = name.clone().into();
let mount_index = MountPointIndex::HttpRoute {
gateway: gateway_name.clone(),
route: route_name.clone(),
};
let plugins = route.plugins;
let rules = route.rules;
let rules = rules
.into_iter()
.enumerate()
.map(|(rule_index, route_rule)| {
let mount_index = MountPointIndex::HttpRouteRule {
rule: rule_index,
gateway: gateway_name.clone(),
route: route_name.clone(),
};
let mut builder = spacegate_kernel::service::http_route::HttpRouteRule::builder();
builder = if let Some(matches) = route_rule.matches {
builder.matches(matches.into_iter().map(convert_config_to_kernel).collect::<Result<Vec<_>, _>>()?)
} else {
builder.match_all()
};
let backends = route_rule
.backends
.into_iter()
.enumerate()
.map(|(backend_index, backend)| {
let mount_index = MountPointIndex::HttpBackend {
backend: backend_index,
rule: rule_index,
gateway: gateway_name.clone(),
route: route_name.clone(),
};
let host = backend.get_host();
let mut builder = spacegate_kernel::service::http_route::HttpBackend::builder();
let plugins = backend.plugins;
#[cfg(feature = "k8s")]
{
use crate::extension::k8s_service::K8sService;
use spacegate_kernel::helper_layers::map_request::{add_extension::add_extension, MapRequestLayer};
use spacegate_kernel::BoxLayer;
if let BackendHost::K8sService(ref data) = backend.host {
let namespace_ext = K8sService(data.clone().into());
builder = builder.plugin(BoxLayer::new(MapRequestLayer::new(add_extension(namespace_ext, true))))
}
}
builder = builder.host(host);
if let Some(port) = backend.port {
builder = builder.port(port)
}
if let Some(timeout) = backend.timeout_ms.map(|timeout| Duration::from_millis(timeout as u64)) {
builder = builder.timeout(timeout)
}
let mut layer = if let BackendHost::File { path } = backend.host {
builder.file().path(path).build()
} else if let Some(protocol) = backend.protocol {
builder.schema(protocol.to_string()).build()
} else {
builder.build()
};
if backend.downgrade_http2.is_some() {
if let spacegate_kernel::service::http_route::Backend::Http { version, .. } = &mut layer.backend {
version.replace(Version::HTTP_11);
}
}
global_batch_mount_plugin(plugins, &mut layer, mount_index);
Result::<_, BoxError>::Ok(layer)
})
.collect::<Result<Vec<_>, _>>()?;
builder = builder.backends(backends);
if let Some(timeout) = route_rule.timeout_ms {
builder = builder.timeout(Duration::from_millis(timeout as u64));
}
let mut layer = builder.build();
global_batch_mount_plugin(route_rule.plugins, &mut layer, mount_index);
Result::<_, BoxError>::Ok(layer)
})
.collect::<Result<Vec<_>, _>>()?;
let mut layer =
spacegate_kernel::service::http_route::HttpRoute::builder().hostnames(route.hostnames.unwrap_or_default()).rules(rules).priority(route.priority).build();
global_batch_mount_plugin(plugins, &mut layer, mount_index);
Ok((name, layer))
})
.collect::<Result<HashMap<String, _>, _>>()
}
pub(crate) fn create_service(item: ConfigItem, reloader: Reloader<HttpRouterService>) -> Result<ArcHyperService, BoxError> {
let gateway_name: Arc<str> = item.gateway.name.into();
let http_routes = item.routes;
let routes = collect_http_route(gateway_name.clone(), http_routes)?;
let plugins = item.gateway.plugins.clone();
let mut builder = spacegate_kernel::service::http_gateway::Gateway::builder(gateway_name.clone());
if let Some(enable) = item.gateway.parameters.enable_x_request_id {
builder = builder.x_request_id(enable);
}
let mut layer = builder.http_routers(routes).http_route_reloader(reloader).build();
global_batch_mount_plugin(plugins, &mut layer, MountPointIndex::Gateway { gateway: gateway_name });
let service = layer.as_service();
Ok(service)
}
pub(crate) fn create_router_service(gateway_name: Arc<str>, http_routes: BTreeMap<String, crate::SgHttpRoute>) -> Result<HttpRouterService, BoxError> {
let routes = collect_http_route(gateway_name, http_routes.clone())?;
let service = create_http_router(routes.values(), default_gateway_route_fallback());
Ok(service)
}
pub struct RunningSgGateway {
pub gateway_name: Arc<str>,
token: CancellationToken,
handle: tokio::task::JoinHandle<()>,
pub reloader: Reloader<HttpRouterService>,
shutdown_timeout: Duration,
}
impl std::fmt::Debug for RunningSgGateway {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RunningSgGateway").field("shutdown_timeout", &self.shutdown_timeout).finish()
}
}
pub static GLOBAL_STORE: OnceLock<Arc<Mutex<HashMap<String, RunningSgGateway>>>> = OnceLock::new();
impl RunningSgGateway {
pub async fn global_init(config: Config, signal: CancellationToken) {
for (id, spec) in config.plugins.into_inner() {
if let Err(err) = PluginRepository::global().create_or_update_instance(PluginConfig { id: id.clone(), spec }) {
tracing::error!("[SG.Config] fail to init plugin [{id}]: {err}", id = id.to_string());
}
}
for (name, item) in config.gateways {
match RunningSgGateway::create(item, signal.child_token()) {
Ok(inst) => RunningSgGateway::global_save(name, inst),
Err(e) => {
tracing::error!("[SG.Config] fail to init gateway [{name}]: {e}")
}
}
}
}
pub async fn global_reset() {
let store = Self::global_store();
let mut task = tokio::task::JoinSet::new();
{
let mut g_store = store.lock().expect("poisoned lock");
for (_, s) in g_store.drain() {
task.spawn(s.shutdown());
}
}
while let Some(res) = task.join_next().await {
res.expect("tokio join error")
}
PluginRepository::global().clear_instances()
}
pub fn global_store() -> Arc<Mutex<HashMap<String, RunningSgGateway>>> {
GLOBAL_STORE.get_or_init(Default::default).clone()
}
pub fn global_save(gateway_name: impl Into<String>, gateway: RunningSgGateway) {
let global_store = Self::global_store();
let mut global_store = global_store.lock().expect("poisoned lock");
global_store.insert(gateway_name.into(), gateway);
}
pub fn global_remove(gateway_name: impl AsRef<str>) -> Option<RunningSgGateway> {
let global_store = Self::global_store();
let mut global_store = global_store.lock().expect("poisoned lock");
global_store.remove(gateway_name.as_ref())
}
pub fn global_update(gateway_name: impl AsRef<str>, http_routes: BTreeMap<String, crate::SgHttpRoute>) -> Result<(), BoxError> {
let gateway_name = gateway_name.as_ref();
let service = create_router_service(gateway_name.to_string().into(), http_routes)?;
let reloader = {
let store = Self::global_store();
let global_store = store.lock().expect("poisoned lock");
if let Some(gw) = global_store.get(gateway_name) {
gw.reloader.clone()
} else {
warn!("no such gateway in global repository: {gateway_name}");
return Ok(());
}
};
reloader.reload(service);
Ok(())
}
#[instrument(fields(gateway=%config_item.gateway.name), skip_all, err)]
pub fn create(config_item: ConfigItem, cancel_token: CancellationToken) -> Result<Self, BoxError> {
#[allow(unused_mut)]
#[cfg(feature = "cache")]
{
if let Some(url) = &config_item.gateway.parameters.redis_url {
let url: Arc<str> = url.clone().into();
tracing::trace!("Initialize cache client...url:{url}");
spacegate_ext_redis::RedisClientRepo::global().add(&config_item.gateway.name, url.as_ref());
}
}
tracing::info!("[SG.Server] start gateway");
let reloader = <Reloader<HttpRouterService>>::default();
let gateway = config_item.gateway.clone();
let service = create_service(config_item, reloader.clone())?;
if gateway.listeners.is_empty() {
error!("[SG.Server] Missing Listeners");
}
let gateway_name: Arc<str> = Arc::from(gateway.name.to_string());
let mut listens: Vec<SgListen> = Vec::new();
for listener in &gateway.listeners {
let ip = listener.ip.unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
let addr = SocketAddr::new(ip, listener.port);
let mut listen = SgListen::new(addr, cancel_token.child_token());
if let SgProtocolConfig::Https { ref tls } = listener.protocol {
tracing::debug!("[SG.Server] Tls is init...mode:{:?}", tls.mode);
if SgTlsMode::Terminate == tls.mode {
{
let certs = rustls_pemfile::certs(&mut tls.cert.as_bytes()).filter_map(Result::ok).collect::<Vec<_>>();
let mut tls_key = tls.key.as_bytes();
let mut keys = rustls_pemfile::read_all(&mut tls_key).filter_map(Result::ok);
let key = keys.find_map(|key| {
debug!("key item: {:?}", key);
match key {
rustls_pemfile::Item::Pkcs1Key(k) => Some(PrivateKeyDer::Pkcs1(k)),
rustls_pemfile::Item::Pkcs8Key(k) => Some(PrivateKeyDer::Pkcs8(k)),
rustls_pemfile::Item::Sec1Key(k) => Some(PrivateKeyDer::Sec1(k)),
rest => {
warn!("Unsupported key type: {:?}", rest);
None
}
}
});
if let Some(key) = key {
info!("[SG.Server] using cert key {key:?}");
let _ = rustls::crypto::ring::default_provider().install_default();
let builder = rustls::ServerConfig::builder().with_no_client_auth();
let mut tls_server_cfg = if let Some(ref host_name) = listener.hostname {
info!("Using SNI resolver");
let mut resolver = rustls::server::ResolvesServerCertUsingSni::new();
let provider = rustls::crypto::CryptoProvider::get_default().expect("should installed");
let signed_key = provider.key_provider.load_private_key(key)?;
let ck = rustls::sign::CertifiedKey::new(certs, signed_key);
resolver.add(host_name, ck)?;
builder.with_cert_resolver(Arc::new(resolver))
} else {
info!("Using single cert");
builder.with_single_cert(certs, key)?
};
tls_server_cfg.alpn_protocols = vec![b"http/1.1".to_vec(), b"h2".to_vec()];
tls_server_cfg.ignore_client_order = true;
tls_server_cfg.enable_secret_extraction = true;
listen.add_service(service.clone().https(tls_server_cfg))
} else {
error!("[SG.Server] Can not found a valid Tls private key");
}
};
}
} else {
listen.add_service(service.clone().http());
}
listens.push(listen)
}
let cancel_task = cancel_token.clone().cancelled_owned();
let handle = {
let gateway_name = gateway_name.clone();
tokio::task::spawn(async move {
let mut join_set = tokio::task::JoinSet::new();
for listen in listens {
join_set.spawn(async move {
let id = listen.listener_id.clone();
if let Err(e) = listen.listen().await {
tracing::error!("[Sg.Server] listen error: {e}")
}
tracing::info!("[Sg.Server] listener[{id}] quit listening")
});
}
tracing::info!(gateway = gateway_name.as_ref(), "[Sg.Server] start all listeners");
cancel_task.await;
while let Some(result) = join_set.join_next().await {
if let Err(_e) = result {}
}
tracing::info!(gateway = gateway_name.as_ref(), "[Sg.Server] cancelled");
})
};
tracing::info!("[SG.Server] start finished");
Ok(RunningSgGateway {
gateway_name: gateway_name.clone(),
token: cancel_token,
handle,
shutdown_timeout: Duration::from_secs(10),
reloader,
})
}
pub async fn shutdown(self) {
self.token.cancel();
#[cfg(feature = "cache")]
{
let name = self.gateway_name.clone();
tracing::trace!("[SG.Cache] Remove cache client...");
spacegate_ext_redis::global_repo().remove(name.as_ref());
}
match timeout(self.shutdown_timeout, self.handle).await {
Ok(_) => {}
Err(e) => {
tracing::warn!("[SG.Server] Wait shutdown timeout:{e}");
}
};
tracing::info!("[SG.Server] Gateway shutdown");
}
}