selium-net-hyper 1.0.0-alpha.3

Streaming compute fabric
Documentation
//! Server-side HTTP helpers for the Hyper driver.

use std::{convert::Infallible, net::SocketAddr, sync::Arc};

use hyper::{
    Response, StatusCode,
    body::Incoming,
    server::conn::{http1, http2},
    service::service_fn,
};
use hyper_util::rt::{TokioExecutor, TokioIo};
use rustls::ServerConfig;
use selium_abi::NetProtocol;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio_rustls::TlsAcceptor;
use tracing::warn;

use crate::{
    driver::{HyperBody, HyperError, HyperStream, InboundState, PendingRequest},
    wire::{error_response, format_request_bytes, host_matches, parse_response},
};

/// Removes and returns up to `len` bytes from the front of the buffered
/// inbound request data.
///
/// The returned frame always carries `writer_id: 0`. When nothing is buffered
/// (or `len` is zero) the payload is simply empty.
///
/// # Errors
/// Returns `HyperError::Lock` if the request buffer's mutex is poisoned.
pub(crate) fn read_inbound(
    state: &InboundState,
    len: usize,
) -> Result<selium_abi::IoFrame, HyperError> {
    let mut buffered = state.request.lock().map_err(|_| HyperError::Lock)?;
    // Measure first so the mutable `drain` borrow does not overlap a read of
    // the guard; an empty buffer clamps the range to `..0`.
    let available = buffered.len();
    let payload = buffered
        .drain(..len.min(available))
        .collect::<Vec<u8>>();
    Ok(selium_abi::IoFrame {
        writer_id: 0,
        payload,
    })
}

/// Appends guest-written bytes to the buffered HTTP response and, once they
/// form a complete response, hands the whole buffer to the waiting responder.
///
/// Bytes accumulate in `state.response`. They are flushed only while a
/// responder is present, and only when `parse_response` either succeeds or
/// fails with something other than `HyperError::HttpIncomplete`. A malformed
/// response is therefore still delivered rather than held back forever.
/// Presumably the responder is the `PendingRequest::responder` awaited in
/// `handle_incoming_request`, which re-parses the bytes and answers 502 on
/// failure. NOTE(review): confirm against the driver.
///
/// # Errors
/// Returns `HyperError::Lock` if either mutex is poisoned. The bytes have
/// already been appended if only the responder lock fails.
pub(crate) fn write_inbound(state: &InboundState, bytes: &[u8]) -> Result<(), HyperError> {
    let mut buffered = state.response.lock().map_err(|_| HyperError::Lock)?;
    buffered.extend_from_slice(bytes);
    let mut pending = state.responder.lock().map_err(|_| HyperError::Lock)?;

    // `&&` short-circuits, so the buffer is only parsed when someone is
    // actually waiting for the response.
    let complete = pending.is_some()
        && !matches!(
            parse_response(state.protocol, buffered.as_slice()),
            Err(HyperError::HttpIncomplete)
        );
    if !complete {
        return Ok(());
    }

    let payload = std::mem::take(&mut *buffered);
    let responder = pending.take();
    // Release both locks before handing the bytes off.
    drop(buffered);
    drop(pending);
    if let Some(tx) = responder
        && tx.send(payload).is_err()
    {
        tracing::debug!("response receiver dropped before completion");
    }
    Ok(())
}

/// Accepts TCP connections on `listener` in an endless loop, spawning one task
/// per connection that serves it via `serve_connection`.
///
/// Each connection task gets its own clone of `sender`, `domain` and the TLS
/// `server_config`. Connection-level failures are logged and never stop the
/// accept loop.
///
/// Accept failures are logged and retried after a short pause. Without the
/// pause, a persistent error (for example running out of file descriptors)
/// would turn this loop into a busy spin that floods the log and burns CPU.
pub(crate) async fn run_listener(
    listener: TcpListener,
    protocol: NetProtocol,
    domain: String,
    server_config: Arc<ServerConfig>,
    sender: mpsc::Sender<PendingRequest>,
) {
    // Delay before retrying `accept` after it fails.
    const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(100);

    loop {
        let (stream, remote_addr) = match listener.accept().await {
            Ok(pair) => pair,
            Err(err) => {
                warn!(err = %err, "HTTP listener accept failed");
                tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                continue;
            }
        };

        let sender = sender.clone();
        let domain = domain.clone();
        let server_config = Arc::clone(&server_config);
        tokio::spawn(async move {
            if let Err(err) =
                serve_connection(stream, remote_addr, protocol, domain, server_config, sender).await
            {
                warn!(err = %err, "HTTP connection handler failed");
            }
        });
    }
}

async fn serve_connection(
    stream: tokio::net::TcpStream,
    remote_addr: SocketAddr,
    protocol: NetProtocol,
    domain: String,
    server_config: Arc<ServerConfig>,
    sender: mpsc::Sender<PendingRequest>,
) -> Result<(), HyperError> {
    let io: HyperStream = match protocol {
        NetProtocol::Http => Box::new(stream),
        NetProtocol::Https => {
            let acceptor = TlsAcceptor::from(server_config);
            let tls_stream = acceptor.accept(stream).await.map_err(HyperError::Tls)?;
            Box::new(tls_stream)
        }
        _ => return Err(HyperError::UnsupportedProtocol { protocol }),
    };

    let service = service_fn(move |req| {
        let sender = sender.clone();
        let domain = domain.clone();
        async move {
            let response =
                match handle_incoming_request(protocol, req, &domain, &sender, remote_addr).await {
                    Ok(response) => response,
                    Err(err) => {
                        warn!(err = %err, "HTTP request handling failed");
                        error_response(
                            protocol,
                            StatusCode::INTERNAL_SERVER_ERROR,
                            "request failed",
                        )
                    }
                };
            Ok::<_, Infallible>(response)
        }
    });

    match protocol {
        NetProtocol::Http => {
            let io = TokioIo::new(io);
            http1::Builder::new()
                .serve_connection(io, service)
                .await
                .map_err(HyperError::Hyper)
        }
        NetProtocol::Https => {
            let io = TokioIo::new(io);
            http2::Builder::new(TokioExecutor::new())
                .serve_connection(io, service)
                .await
                .map_err(HyperError::Hyper)
        }
        _ => Err(HyperError::UnsupportedProtocol { protocol }),
    }
}

/// Validates one inbound HTTP request, forwards it to the guest and converts
/// the guest's reply into a hyper response.
///
/// When `domain` is non-empty, the request's target authority must satisfy
/// `host_matches(domain, ..)`. A mismatch is answered with
/// 421 Misdirected Request, and an absent (or non-ASCII) authority with
/// 400 Bad Request. The authority comes from the request URI (the HTTP/2
/// `:authority` pseudo-header or an HTTP/1.1 absolute-form target), falling
/// back to the `Host` header.
///
/// Accepted requests are serialized with `format_request_bytes` and sent on
/// `sender` together with a oneshot responder and the peer address. The
/// function then waits for the guest's raw response bytes. If those bytes do
/// not parse, a 502 Bad Gateway is returned instead.
///
/// # Errors
/// Propagates `format_request_bytes` failures. Returns
/// `HyperError::ListenerClosed` if the request channel is closed, and
/// `HyperError::ResponseChannelClosed` if the responder is dropped without a
/// reply.
async fn handle_incoming_request(
    protocol: NetProtocol,
    request: hyper::Request<Incoming>,
    domain: &str,
    sender: &mpsc::Sender<PendingRequest>,
    remote_addr: SocketAddr,
) -> Result<Response<HyperBody>, HyperError> {
    if !domain.is_empty() {
        // HTTP/2 clients convey the target in `:authority` (surfaced by hyper
        // as the URI authority) and typically omit `Host`, so checking only
        // `Host` rejected every HTTP/2 request. For HTTP/1.1 absolute-form
        // targets the URI authority also takes precedence over `Host`
        // (RFC 9112 §3.2.2). Origin-form HTTP/1.1 requests carry no URI
        // authority and fall back to the `Host` header.
        let host = request
            .uri()
            .authority()
            .map(|authority| authority.as_str())
            .or_else(|| {
                request
                    .headers()
                    .get(hyper::header::HOST)
                    .and_then(|value| value.to_str().ok())
            });
        match host {
            Some(host) if host_matches(domain, host) => {}
            Some(_) => {
                return Ok(error_response(
                    protocol,
                    StatusCode::MISDIRECTED_REQUEST,
                    "host mismatch",
                ));
            }
            None => {
                return Ok(error_response(
                    protocol,
                    StatusCode::BAD_REQUEST,
                    "missing host",
                ));
            }
        }
    }

    let request_bytes = format_request_bytes(request, protocol).await?;
    let (tx, rx) = tokio::sync::oneshot::channel();
    sender
        .send(PendingRequest {
            request_bytes,
            responder: tx,
            remote_addr: remote_addr.to_string(),
        })
        .await
        .map_err(|_| HyperError::ListenerClosed)?;

    // NOTE(review): there is no timeout here; a guest that never replies keeps
    // this request (and its connection slot) pending indefinitely.
    let response_bytes = rx.await.map_err(|_| HyperError::ResponseChannelClosed)?;
    match parse_response(protocol, &response_bytes) {
        Ok(response) => Ok(response),
        Err(err) => {
            warn!(err = %err, "invalid HTTP response from guest");
            Ok(error_response(
                protocol,
                StatusCode::BAD_GATEWAY,
                "invalid response",
            ))
        }
    }
}