use std::clone::Clone;
use std::convert::Infallible;
use std::env;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context as TaskContext, Poll};
use anyhow::{Context, Result, anyhow};
use bytes::Bytes;
use http::StatusCode;
use http::uri::{PathAndQuery, Uri};
use http_body_util::combinators::UnsyncBoxBody;
use http_body_util::{BodyExt, Full};
use hyper::body::{Body, Frame, Incoming, SizeHint};
use hyper::header::{FORWARDED, HOST};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use omnia::State;
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tracing::{Instrument, debug_span};
use wasmtime::Store;
use wasmtime_wasi_http::io::TokioIo;
use wasmtime_wasi_http::p3::WasiHttpView;
use wasmtime_wasi_http::p3::bindings::ServiceIndices;
use wasmtime_wasi_http::p3::bindings::http::types::{self as wasi, ErrorCode};
/// Body type of every response this server hands back to hyper: a boxed
/// (unsync) body of `Bytes` frames whose error type is `anyhow::Error`.
type OutgoingBody = UnsyncBoxBody<Bytes, anyhow::Error>;
/// Default listen address, used when the `HTTP_ADDR` environment variable is
/// not set.
const HTTP_ADDR: &str = "0.0.0.0:8080";
pub async fn serve<S>(state: &S) -> Result<()>
where
S: State,
S::StoreCtx: WasiHttpView,
{
let component = env::var("COMPONENT").unwrap_or_else(|_| "unknown".into());
let addr = env::var("HTTP_ADDR").unwrap_or_else(|_| HTTP_ADDR.into());
let listener = TcpListener::bind(&addr).await?;
tracing::info!("{component} http server listening on: {addr}");
let handler = Handler {
state: Arc::new(state.clone()),
component,
};
loop {
let (stream, _) = listener.accept().await?;
stream.set_nodelay(true)?;
let stream = TokioIo::new(stream);
let handler = handler.clone();
tokio::spawn(async move {
let mut http1 = http1::Builder::new();
http1.keep_alive(true);
if let Err(e) = http1
.serve_connection(
stream,
service_fn(move |request| {
let handler = handler.clone();
async move {
let response = handler.handle(request).await.unwrap_or_else(|e| {
tracing::error!("Error proxying request: {e}");
internal_error()
});
if response.status() >= StatusCode::INTERNAL_SERVER_ERROR {
tracing::error!(
monotonic_counter.processing_errors = 1,
service = %handler.component,
error = format!("{response:?}"),
);
}
Ok::<_, Infallible>(response)
}
}),
)
.await
{
tracing::error!("connection error: {e:?}");
}
});
}
}
/// Per-server request handler; cloned for every connection and every request.
#[derive(Clone)]
struct Handler<S>
where
S: State,
S::StoreCtx: WasiHttpView,
{
/// Shared runtime state from which `handle` obtains the pre-instantiated
/// component (`instance_pre`) and fresh store data (`store`).
state: Arc<S>,
/// Component name, attached as the `service` field when a 5xx response is logged.
component: String,
}
impl<S> Handler<S>
where
S: State,
S::StoreCtx: WasiHttpView,
{
/// Forwards one incoming hyper request to the guest component and returns
/// the guest's response.
///
/// A new `Store` and a new component instance are created on every call.
/// The guest call itself runs on a spawned task; its outcome is delivered
/// back to this function through a oneshot channel.
///
/// # Errors
///
/// Returns an error if `fix_request` fails, if building the service
/// indices, instantiating the component or loading the service fails, if
/// the guest returns an error, if the guest response cannot be converted to
/// an HTTP response, or if the spawned task finishes without sending
/// anything on the channel.
async fn handle(
&self, request: hyper::Request<Incoming>,
) -> Result<hyper::Response<OutgoingBody>> {
tracing::debug!("handling request: {request:?}");
// Rebuild the request URI with scheme and authority before handing it to the guest.
let request = fix_request(request).context("preparing request")?;
let instance_pre = self.state.instance_pre();
let store_data = self.state.store();
let mut store = Store::new(instance_pre.engine(), store_data);
let indices = ServiceIndices::new(instance_pre)?;
let instance = instance_pre.instantiate_async(&mut store).await?;
let service = indices.load(&mut store, &instance)?;
// Channel that carries either the guest's response or the reason there is none.
let (sender, receiver) = oneshot::channel::<Result<hyper::Response<OutgoingBody>>>();
tokio::spawn(async move {
// Best-effort error report: the send result is ignored because the
// receiving side may already be gone.
let send_err = |sender: oneshot::Sender<_>, e: anyhow::Error| {
_ = sender.send(Err(e));
};
let result = store
.run_concurrent(async |store| {
// Convert the hyper request into the WASI HTTP request type.
let (parts, body) = request.into_parts();
let body = body.map_err(ErrorCode::from_hyper_request_error);
let http_req = http::Request::from_parts(parts, body);
let (request, io) = wasi::Request::from_http(http_req);
// The outer `?` leaves the closure without sending anything on
// `sender`; the inner `Err` is an error reported by the guest.
let wasi_resp = match service.handle(store, request).await? {
Ok(resp) => resp,
Err(e) => {
send_err(sender, anyhow!("guest error: {e}"));
return anyhow::Ok(());
}
};
let resp = match store.with(|mut store| wasi_resp.into_http(&mut store, io)) {
Ok(resp) => resp,
Err(e) => {
send_err(sender, anyhow!("converting guest response: {e}"));
return anyhow::Ok(());
}
};
// `body_done_tx` travels inside the response body; `body_done_rx`
// resolves once that sender is dropped along with the body.
let (body_done_tx, body_done_rx) = oneshot::channel::<()>();
let resp = resp.map(|body| {
BodyDoneWrapper {
body: body.map_err(Into::into),
_tx: body_done_tx,
}
.boxed_unsync()
});
// Stay inside `run_concurrent` until the response body has been
// dropped — presumably so the guest can keep producing body
// frames while hyper writes them out; TODO confirm.
if sender.send(Ok(resp)).is_ok() {
_ = body_done_rx.await;
}
anyhow::Ok(())
})
.instrument(debug_span!("http-request"))
.await;
match result {
Err(e) => tracing::error!("run_concurrent error: {e:?}"),
Ok(Err(e)) => tracing::error!("guest error: {e:#}"),
Ok(Ok(())) => {}
}
});
// NOTE(review): the receive also fails when the closure above exits through
// `?` on `service.handle(..)` without sending, so "guest task panicked" can
// be reported for failures that are not panics — consider a broader message.
let response = receiver.await.map_err(|_recv| anyhow!("guest task panicked"))??;
tracing::debug!("received response: {response:?}");
Ok(response)
}
}
/// Rewrites the request URI into absolute form
/// (`scheme://authority/path?query`).
///
/// The path and query are taken from the incoming URI (`/` when absent). The
/// authority and scheme come from the `host` and `proto` parameters of the
/// `Forwarded` header when they are present. Whatever the `Forwarded` header
/// does not supply falls back to the `Host` header for the authority and to
/// `http` for the scheme, so a `Forwarded` header that carries only some
/// parameters (for example just `for=`) no longer yields an incomplete URI.
///
/// # Errors
///
/// Returns an error if a consulted header is not visible ASCII, if no
/// authority is available from either `Forwarded` or `Host`
/// (`missing host header`), or if the resulting parts do not form a valid
/// URI.
fn fix_request(request: hyper::Request<Incoming>) -> Result<hyper::Request<Incoming>> {
    let path_and_query = request
        .uri()
        .path_and_query()
        .cloned()
        .unwrap_or_else(|| PathAndQuery::from_static("/"));

    let (forwarded_host, forwarded_proto) = match request.headers().get(FORWARDED) {
        Some(value) => forwarded_host_and_proto(value.to_str()?),
        None => (None, None),
    };

    let authority = match forwarded_host {
        Some(host) => host,
        None => request
            .headers()
            .get(HOST)
            .ok_or_else(|| anyhow!("missing host header"))?
            .to_str()?,
    };
    let scheme = forwarded_proto.unwrap_or("http");

    // The builder rejects a scheme without an authority (and vice versa), so
    // both are always supplied together.
    let uri = Uri::builder()
        .scheme(scheme)
        .authority(authority)
        .path_and_query(path_and_query)
        .build()?;

    let (mut parts, body) = request.into_parts();
    parts.uri = uri;
    Ok(hyper::Request::from_parts(parts, body))
}

/// Extracts the first non-empty `host` and `proto` parameters from a
/// `Forwarded` header value, returned as `(host, proto)`.
///
/// Parameter names are matched case-insensitively and one pair of surrounding
/// double quotes is removed from a value. Elements are split on `,` and pairs
/// on `;` without tracking quoting, so a quoted value that itself contains one
/// of those separators is not supported.
fn forwarded_host_and_proto(header: &str) -> (Option<&str>, Option<&str>) {
    let mut host = None;
    let mut proto = None;

    let pairs = header.split(',').flat_map(|element| element.split(';'));
    for pair in pairs {
        let Some((name, value)) = pair.split_once('=') else {
            continue;
        };
        let value = value.trim();
        let value = value
            .strip_prefix('"')
            .and_then(|inner| inner.strip_suffix('"'))
            .unwrap_or(value);
        if value.is_empty() {
            continue;
        }

        let name = name.trim();
        if host.is_none() && name.eq_ignore_ascii_case("host") {
            host = Some(value);
        } else if proto.is_none() && name.eq_ignore_ascii_case("proto") {
            proto = Some(value);
        }
    }

    (host, proto)
}
/// Body adapter that pairs an inner body with a oneshot sender, so the sender
/// is dropped together with the body.
struct BodyDoneWrapper<B> {
/// The wrapped body; all `Body` methods are forwarded to it.
body: B,
/// Never sent on; it is held only so that dropping the wrapper drops it.
_tx: oneshot::Sender<()>,
}
impl<B> Body for BodyDoneWrapper<B>
where
B: Body + Unpin,
{
type Data = B::Data;
type Error = B::Error;
fn poll_frame(
self: Pin<&mut Self>, cx: &mut TaskContext<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let inner = Pin::new(&mut self.get_mut().body);
inner.poll_frame(cx)
}
fn size_hint(&self) -> SizeHint {
self.body.size_hint()
}
fn is_end_stream(&self) -> bool {
self.body.is_end_stream()
}
}
/// Static HTML page used as the body of the fallback 500 response built by
/// `internal_error`.
const BODY: &str = r"<!doctype html>
<html>
<head>
<title>500 Internal Server Error</title>
</head>
<body>
<center>
<h1>500 Internal Server Error</h1>
<hr>
<pre>Guest error</pre>
</center>
</body>
</html>";
fn internal_error() -> hyper::Response<OutgoingBody> {
let body = Full::new(Bytes::from(BODY)).map_err(Into::into).boxed_unsync();
hyper::Response::builder()
.status(500)
.header("Content-Type", "text/html; charset=UTF-8")
.body(body)
.expect("should build internal error response")
}