use core::str::FromStr as _;
use core::time::Duration;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::{bail, Context as _};
use axum::extract;
use axum::handler::Handler;
use axum_server::tls_rustls::RustlsConfig;
use tokio::sync::RwLock;
use tracing::{debug, error, info, instrument};
use wasmcloud_core::http::{default_listen_address, load_settings, ServiceSettings};
use wasmcloud_provider_sdk::core::LinkName;
use wasmcloud_provider_sdk::provider::WrpcClient;
use wasmcloud_provider_sdk::{get_connection, HostData, LinkConfig, LinkDeleteInfo, Provider};
use crate::{build_request, get_cors_layer, get_tcp_listener, invoke_component};
type HandlerLookup =
HashMap<SocketAddr, (Arc<HttpServerCore>, Vec<(Arc<str>, Arc<str>, WrpcClient)>)>;
#[derive(Clone)]
pub struct HttpServerProvider {
default_address: SocketAddr,
handlers_by_socket: Arc<RwLock<HandlerLookup>>,
sockets_by_link_name: Arc<RwLock<HashMap<LinkName, SocketAddr>>>,
}
impl Default for HttpServerProvider {
fn default() -> Self {
Self {
default_address: default_listen_address(),
handlers_by_socket: Arc::default(),
sockets_by_link_name: Arc::default(),
}
}
}
impl HttpServerProvider {
pub fn new(host_data: &HostData) -> anyhow::Result<Self> {
let default_address = host_data
.config
.get("default_address")
.map(|s| SocketAddr::from_str(s))
.transpose()
.context("failed to parse default_address")?
.unwrap_or_else(default_listen_address);
Ok(Self {
default_address,
handlers_by_socket: Arc::default(),
sockets_by_link_name: Arc::default(),
})
}
}
impl Provider for HttpServerProvider {
async fn receive_link_config_as_source(
&self,
link_config: LinkConfig<'_>,
) -> anyhow::Result<()> {
let settings = match load_settings(Some(self.default_address), link_config.config)
.context("httpserver failed to load settings for component")
{
Ok(settings) => settings,
Err(e) => {
error!(
config = ?link_config.config,
"httpserver failed to load settings for component: {}", e.to_string()
);
bail!(e);
}
};
let wrpc = get_connection()
.get_wrpc_client(link_config.target_id)
.await
.context("failed to construct wRPC client")?;
let component_meta = (
Arc::from(link_config.target_id),
Arc::from(link_config.link_name),
wrpc,
);
let mut sockets_by_link_name = self.sockets_by_link_name.write().await;
let mut handlers_by_socket = self.handlers_by_socket.write().await;
match sockets_by_link_name.entry(link_config.link_name.to_string()) {
std::collections::hash_map::Entry::Occupied(v) => {
bail!(
"an address mapping for address [{}] the link [{}] already exists, overwriting links is not currently supported",
v.get().ip().to_string(),
link_config.link_name,
)
}
std::collections::hash_map::Entry::Vacant(v) => {
v.insert(settings.address);
}
}
match handlers_by_socket.entry(settings.address) {
std::collections::hash_map::Entry::Occupied(mut v) => {
v.get_mut().1.push(component_meta);
}
std::collections::hash_map::Entry::Vacant(v) => {
let http_server = HttpServerCore::new(
Arc::new(settings),
link_config.target_id,
self.handlers_by_socket.clone(),
)
.await
.context("httpserver failed to start listener for component")?;
v.insert((Arc::new(http_server), vec![component_meta]));
}
}
Ok(())
}
#[instrument(level = "info", skip_all, fields(target_id = info.get_target_id()))]
async fn delete_link_as_source(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
let component_id = info.get_target_id();
let link_name = info.get_link_name();
let mut sockets_by_link_name = self.sockets_by_link_name.write().await;
if let Some(addr) = sockets_by_link_name.get(link_name) {
let mut handlers_by_socket = self.handlers_by_socket.write().await;
if let Some((server, component_metas)) = handlers_by_socket.get_mut(addr) {
if let Some(idx) = component_metas
.iter()
.position(|(c, l, ..)| c.as_ref() == component_id && l.as_ref() == link_name)
{
component_metas.remove(idx);
}
if component_metas.is_empty() {
info!(
address = addr.to_string(),
"last component removed for address, shutting down server"
);
server.handle.shutdown();
handlers_by_socket.remove(addr);
sockets_by_link_name.remove(link_name);
}
}
}
Ok(())
}
async fn shutdown(&self) -> anyhow::Result<()> {
self.sockets_by_link_name.write().await.clear();
self.handlers_by_socket.write().await.clear();
Ok(())
}
}
#[derive(Clone)]
struct RequestContext {
server_address: SocketAddr,
settings: Arc<ServiceSettings>,
scheme: http::uri::Scheme,
handlers_by_socket: Arc<RwLock<HandlerLookup>>,
}
#[instrument(level = "debug", skip(settings, handlers_by_socket))]
async fn handle_request(
extract::State(RequestContext {
server_address,
settings,
scheme,
handlers_by_socket,
}): extract::State<RequestContext>,
axum_extra::extract::Host(authority): axum_extra::extract::Host,
request: extract::Request,
) -> impl axum::response::IntoResponse {
let (component_id, wrpc) = {
let Some((component_id, wrpc)) = handlers_by_socket
.read()
.await
.get(&server_address)
.and_then(|v| v.1.first())
.map(|(component_id, _, wrpc)| (Arc::clone(component_id), wrpc.clone()))
else {
return Err((
http::StatusCode::INTERNAL_SERVER_ERROR,
"no targets for HTTP request",
))?;
};
(component_id, wrpc)
};
let timeout = settings.timeout_ms.map(Duration::from_millis);
let req = build_request(request, scheme, authority, &settings)?;
axum::response::Result::<_, axum::response::ErrorResponse>::Ok(
invoke_component(
&wrpc,
&component_id,
req,
timeout,
settings.cache_control.as_ref(),
)
.await,
)
}
#[derive(Debug)]
pub struct HttpServerCore {
handle: axum_server::Handle,
task: tokio::task::JoinHandle<()>,
}
impl HttpServerCore {
#[instrument(skip(handlers_by_socket))]
pub async fn new(
settings: Arc<ServiceSettings>,
target: &str,
handlers_by_socket: Arc<RwLock<HandlerLookup>>,
) -> anyhow::Result<Self> {
let addr = settings.address;
info!(
%addr,
component_id = target,
"httpserver starting listener for target",
);
let cors = get_cors_layer(&settings)?;
let service = handle_request.layer(cors);
let handle = axum_server::Handle::new();
let listener = get_tcp_listener(&settings)
.with_context(|| format!("failed to create listener (is [{addr}] already in use?)"))?;
let target = target.to_owned();
let task_handle = handle.clone();
let task = if let (Some(crt), Some(key)) =
(&settings.tls_cert_file, &settings.tls_priv_key_file)
{
debug!(?addr, "bind HTTPS listener");
let tls = RustlsConfig::from_pem_file(crt, key)
.await
.context("failed to construct TLS config")?;
let srv = axum_server::from_tcp_rustls(listener, tls);
tokio::spawn(async move {
if let Err(e) = srv
.handle(task_handle)
.serve(
service
.with_state(RequestContext {
server_address: addr,
settings,
scheme: http::uri::Scheme::HTTPS,
handlers_by_socket,
})
.into_make_service(),
)
.await
{
error!(error = %e, component_id = target, "failed to serve HTTPS for component");
}
})
} else {
debug!(?addr, "bind HTTP listener");
let mut srv = axum_server::from_tcp(listener);
srv.http_builder().http1().keep_alive(false);
tokio::spawn(async move {
if let Err(e) = srv
.handle(task_handle)
.serve(
service
.with_state(RequestContext {
server_address: addr,
settings,
scheme: http::uri::Scheme::HTTP,
handlers_by_socket,
})
.into_make_service(),
)
.await
{
error!(error = %e, component_id = target, "failed to serve HTTP for component");
}
})
};
Ok(Self { handle, task })
}
}
impl Drop for HttpServerCore {
fn drop(&mut self) {
self.handle.shutdown();
self.task.abort();
}
}