wasmcloud-provider-http-server 0.27.1

Http server for wasmcloud, using Axum. This package provides a library, and a capability provider
Documentation
//! Implementation of the `wrpc:http/incoming-handler` provider in address mode
//!
//! This provider listens on a new address for each component that it links to.

use core::str::FromStr as _;
use core::time::Duration;

use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;

use anyhow::{bail, Context as _};
use axum::extract;
use axum::handler::Handler;
use axum_server::tls_rustls::RustlsConfig;
use tokio::sync::RwLock;
use tracing::{debug, error, info, instrument};
use wasmcloud_core::http::{default_listen_address, load_settings, ServiceSettings};
use wasmcloud_provider_sdk::core::LinkName;
use wasmcloud_provider_sdk::provider::WrpcClient;
use wasmcloud_provider_sdk::{get_connection, HostData, LinkConfig, LinkDeleteInfo, Provider};

use crate::{build_request, get_cors_layer, get_tcp_listener, invoke_component};

/// Lookup for handlers by socket
///
/// Indexed first by socket address to more easily detect duplicates,
/// with the http server stored, along with a list (order matters) of components that were registered
type HandlerLookup =
    HashMap<SocketAddr, (Arc<HttpServerCore>, Vec<(Arc<str>, Arc<str>, WrpcClient)>)>;

/// `wrpc:http/incoming-handler` provider implementation in address mode
#[derive(Clone)]
pub struct HttpServerProvider {
    default_address: SocketAddr,

    /// Lookup of components that handle requests {addr -> (server, (component id, link name))}
    handlers_by_socket: Arc<RwLock<HandlerLookup>>,

    /// Sockets that are relevant to a given link name
    ///
    /// This structure is generally used as a look up into `handlers_by_socket`
    sockets_by_link_name: Arc<RwLock<HashMap<LinkName, SocketAddr>>>,
}

impl Default for HttpServerProvider {
    fn default() -> Self {
        Self {
            default_address: default_listen_address(),
            handlers_by_socket: Arc::default(),
            sockets_by_link_name: Arc::default(),
        }
    }
}

impl HttpServerProvider {
    /// Create a new instance of the HTTP server provider
    pub fn new(host_data: &HostData) -> anyhow::Result<Self> {
        let default_address = host_data
            .config
            .get("default_address")
            .map(|s| SocketAddr::from_str(s))
            .transpose()
            .context("failed to parse default_address")?
            .unwrap_or_else(default_listen_address);

        Ok(Self {
            default_address,
            handlers_by_socket: Arc::default(),
            sockets_by_link_name: Arc::default(),
        })
    }
}

impl Provider for HttpServerProvider {
    /// This is called when the HTTP server provider is linked to a component
    ///
    /// This HTTP server mode will listen on a new address for each component that it links to.
    async fn receive_link_config_as_source(
        &self,
        link_config: LinkConfig<'_>,
    ) -> anyhow::Result<()> {
        let settings = match load_settings(Some(self.default_address), link_config.config)
            .context("httpserver failed to load settings for component")
        {
            Ok(settings) => settings,
            Err(e) => {
                error!(
                    config = ?link_config.config,
                    "httpserver failed to load settings for component: {}", e.to_string()
                );
                bail!(e);
            }
        };

        let wrpc = get_connection()
            .get_wrpc_client(link_config.target_id)
            .await
            .context("failed to construct wRPC client")?;
        let component_meta = (
            Arc::from(link_config.target_id),
            Arc::from(link_config.link_name),
            wrpc,
        );
        let mut sockets_by_link_name = self.sockets_by_link_name.write().await;
        let mut handlers_by_socket = self.handlers_by_socket.write().await;

        match sockets_by_link_name.entry(link_config.link_name.to_string()) {
            // If a mapping already exists, and the stored address is different, disallow overwriting
            std::collections::hash_map::Entry::Occupied(v) => {
                bail!(
                    "an address mapping for address [{}] the link [{}] already exists, overwriting links is not currently supported",
                    v.get().ip().to_string(),
                    link_config.link_name,
                )
            }
            // If a mapping does exist, we can create a new mapping for the address
            std::collections::hash_map::Entry::Vacant(v) => {
                v.insert(settings.address);
            }
        }

        match handlers_by_socket.entry(settings.address) {
            // If handlers already exist for the address, add the newly linked component
            //
            // NOTE: only components at the head of the list are served requests
            std::collections::hash_map::Entry::Occupied(mut v) => {
                v.get_mut().1.push(component_meta);
            }
            // If a handler does not already exist, make a new server and insert
            std::collections::hash_map::Entry::Vacant(v) => {
                // Start a server instance that calls the given component
                let http_server = HttpServerCore::new(
                    Arc::new(settings),
                    link_config.target_id,
                    self.handlers_by_socket.clone(),
                )
                .await
                .context("httpserver failed to start listener for component")?;
                v.insert((Arc::new(http_server), vec![component_meta]));
            }
        }

        Ok(())
    }

    /// Handle notification that a link is dropped - stop the http listener
    #[instrument(level = "info", skip_all, fields(target_id = info.get_target_id()))]
    async fn delete_link_as_source(&self, info: impl LinkDeleteInfo) -> anyhow::Result<()> {
        let component_id = info.get_target_id();
        let link_name = info.get_link_name();

        // Retrieve the thing by link name
        let mut sockets_by_link_name = self.sockets_by_link_name.write().await;
        if let Some(addr) = sockets_by_link_name.get(link_name) {
            let mut handlers_by_socket = self.handlers_by_socket.write().await;
            if let Some((server, component_metas)) = handlers_by_socket.get_mut(addr) {
                // If the component id & link name pair is present, remove it
                if let Some(idx) = component_metas
                    .iter()
                    .position(|(c, l, ..)| c.as_ref() == component_id && l.as_ref() == link_name)
                {
                    component_metas.remove(idx);
                }

                // If the component was the last one, we can remove the server
                if component_metas.is_empty() {
                    info!(
                        address = addr.to_string(),
                        "last component removed for address, shutting down server"
                    );
                    server.handle.shutdown();
                    handlers_by_socket.remove(addr);
                    sockets_by_link_name.remove(link_name);
                }
            }
        }

        Ok(())
    }

    /// Handle shutdown request by shutting down all the http server threads
    async fn shutdown(&self) -> anyhow::Result<()> {
        // Empty the component link data and stop all servers
        self.sockets_by_link_name.write().await.clear();
        self.handlers_by_socket.write().await.clear();
        Ok(())
    }
}

#[derive(Clone)]
struct RequestContext {
    /// Address of the server, used for handler lookup
    server_address: SocketAddr,
    /// Settings that can be
    settings: Arc<ServiceSettings>,
    /// HTTP scheme
    scheme: http::uri::Scheme,
    /// Handlers for components
    handlers_by_socket: Arc<RwLock<HandlerLookup>>,
}

/// Handle an HTTP request by invoking the target component as configured in the listener
#[instrument(level = "debug", skip(settings, handlers_by_socket))]
async fn handle_request(
    extract::State(RequestContext {
        server_address,
        settings,
        scheme,
        handlers_by_socket,
    }): extract::State<RequestContext>,
    axum_extra::extract::Host(authority): axum_extra::extract::Host,
    request: extract::Request,
) -> impl axum::response::IntoResponse {
    let (component_id, wrpc) = {
        let Some((component_id, wrpc)) = handlers_by_socket
            .read()
            .await
            .get(&server_address)
            .and_then(|v| v.1.first())
            .map(|(component_id, _, wrpc)| (Arc::clone(component_id), wrpc.clone()))
        else {
            return Err((
                http::StatusCode::INTERNAL_SERVER_ERROR,
                "no targets for HTTP request",
            ))?;
        };
        (component_id, wrpc)
    };

    let timeout = settings.timeout_ms.map(Duration::from_millis);
    let req = build_request(request, scheme, authority, &settings)?;
    axum::response::Result::<_, axum::response::ErrorResponse>::Ok(
        invoke_component(
            &wrpc,
            &component_id,
            req,
            timeout,
            settings.cache_control.as_ref(),
        )
        .await,
    )
}

/// An asynchronous `wrpc:http/incoming-handler` with support for CORS and TLS
#[derive(Debug)]
pub struct HttpServerCore {
    /// The handle to the server handling incoming requests
    handle: axum_server::Handle,
    /// The asynchronous task running the server
    task: tokio::task::JoinHandle<()>,
}

impl HttpServerCore {
    #[instrument(skip(handlers_by_socket))]
    pub async fn new(
        settings: Arc<ServiceSettings>,
        target: &str,
        handlers_by_socket: Arc<RwLock<HandlerLookup>>,
    ) -> anyhow::Result<Self> {
        let addr = settings.address;
        info!(
            %addr,
            component_id = target,
            "httpserver starting listener for target",
        );
        let cors = get_cors_layer(&settings)?;
        let service = handle_request.layer(cors);
        let handle = axum_server::Handle::new();
        let listener = get_tcp_listener(&settings)
            .with_context(|| format!("failed to create listener (is [{addr}] already in use?)"))?;

        let target = target.to_owned();
        let task_handle = handle.clone();
        let task = if let (Some(crt), Some(key)) =
            (&settings.tls_cert_file, &settings.tls_priv_key_file)
        {
            debug!(?addr, "bind HTTPS listener");
            let tls = RustlsConfig::from_pem_file(crt, key)
                .await
                .context("failed to construct TLS config")?;

            let srv = axum_server::from_tcp_rustls(listener, tls);
            tokio::spawn(async move {
                if let Err(e) = srv
                    .handle(task_handle)
                    .serve(
                        service
                            .with_state(RequestContext {
                                server_address: addr,
                                settings,
                                scheme: http::uri::Scheme::HTTPS,
                                handlers_by_socket,
                            })
                            .into_make_service(),
                    )
                    .await
                {
                    error!(error = %e, component_id = target, "failed to serve HTTPS for component");
                }
            })
        } else {
            debug!(?addr, "bind HTTP listener");

            let mut srv = axum_server::from_tcp(listener);
            srv.http_builder().http1().keep_alive(false);
            tokio::spawn(async move {
                if let Err(e) = srv
                    .handle(task_handle)
                    .serve(
                        service
                            .with_state(RequestContext {
                                server_address: addr,
                                settings,
                                scheme: http::uri::Scheme::HTTP,
                                handlers_by_socket,
                            })
                            .into_make_service(),
                    )
                    .await
                {
                    error!(error = %e, component_id = target, "failed to serve HTTP for component");
                }
            })
        };

        Ok(Self { handle, task })
    }
}

impl Drop for HttpServerCore {
    /// Drop the client connection. Does not block or fail if the client has already been closed.
    fn drop(&mut self) {
        self.handle.shutdown();
        self.task.abort();
    }
}