forte-sdk 0.3.4

Runtime library for Forte wasm components (wasi:http p3)
Documentation
use std::fmt;
use std::future::Future;

use tracing::Instrument;
use wit_bindgen::rt::async_support::StreamReader;

use crate::bindings::wasi::http::types as p3;
use crate::bindings::{wit_future, wit_stream};
use crate::http::Body;

#[derive(Debug)]
pub enum ServeError {
    InvalidUri(http::uri::InvalidUri),
    BuildRequest(http::Error),
    Headers(p3::HeaderError),
    InvalidStatusCode,
}

impl fmt::Display for ServeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ServeError::InvalidUri(e) => write!(f, "invalid uri: {e}"),
            ServeError::BuildRequest(e) => write!(f, "failed to build http::Request: {e}"),
            ServeError::Headers(e) => write!(f, "invalid response headers: {e:?}"),
            ServeError::InvalidStatusCode => write!(f, "invalid status code"),
        }
    }
}

/// Exposes the wrapped `http` crate error, when there is one, as the
/// [`source`](std::error::Error::source) of a [`ServeError`].
impl std::error::Error for ServeError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ServeError::InvalidUri(e) => Some(e),
            ServeError::BuildRequest(e) => Some(e),
            // Variants are listed explicitly (no `_` arm) so that adding a new
            // variant forces a decision here.
            // NOTE(review): `p3::HeaderError` is not exposed as a source,
            // presumably because it does not implement `Error` — confirm.
            ServeError::Headers(_) | ServeError::InvalidStatusCode => None,
        }
    }
}

/// Handles one incoming wasi:http p3 request.
///
/// The request is converted into an [`http::Request`] whose body is fully
/// read into memory. `dispatch` then runs inside an `http.request` tracing
/// span, and the resulting [`http::Response`] is converted back into a
/// [`p3::Response`]. The response status code is recorded on the span.
///
/// # Errors
///
/// Every failure becomes [`p3::ErrorCode::InternalError`] with a textual
/// description. Request and response conversion errors use their `Display`
/// form; errors returned by `dispatch` use their `Debug` form.
pub async fn serve<F, Fut, E>(
    req: p3::Request,
    dispatch: F,
) -> core::result::Result<p3::Response, p3::ErrorCode>
where
    F: FnOnce(::http::Request<Vec<u8>>) -> Fut,
    Fut: Future<Output = core::result::Result<::http::Response<Body>, E>>,
    E: fmt::Debug,
{
    crate::otel::init_once();

    let request = p3_to_http_request(req)
        .await
        .map_err(|err| p3::ErrorCode::InternalError(Some(format!("{err}"))))?;

    let method = request.method().as_str().to_owned();
    let path = request.uri().path().to_owned();
    let span = tracing::info_span!(
        "http.request",
        otel.name = %format!("{method} {path}"),
        http.request.method = %method,
        url.path = %path,
        http.response.status_code = tracing::field::Empty,
    );

    // A clone instruments the dispatch future; the original handle is kept so
    // the status code can be filled in once the response is known.
    let response = dispatch(request)
        .instrument(span.clone())
        .await
        .map_err(|err| p3::ErrorCode::InternalError(Some(format!("{err:?}"))))?;

    span.record("http.response.status_code", response.status().as_u16());

    http_response_to_p3(response)
        .await
        .map_err(|err| p3::ErrorCode::InternalError(Some(format!("{err}"))))
}

/// Converts an incoming wasi:http p3 request into an [`http::Request`] whose
/// body is fully buffered in memory.
///
/// The URI authority is taken from the request's authority. If that is
/// missing or empty, the `Host` header (when it is valid UTF-8) is used
/// instead. When neither is available, the URI is built in origin-form (path
/// and query only). `http::Uri` rejects an absolute URI with an empty
/// authority such as `http:///path`, so this avoids that failure.
///
/// # Errors
///
/// * [`ServeError::InvalidUri`] if the assembled URI does not parse.
/// * [`ServeError::BuildRequest`] if `http::Request::builder` rejects the
///   request, e.g. because of an invalid header name or value.
async fn p3_to_http_request(
    req: p3::Request,
) -> core::result::Result<::http::Request<Vec<u8>>, ServeError> {
    let method = method_from_p3(req.get_method().await);

    // A missing scheme defaults to plain http.
    let scheme_str = match req.get_scheme().await {
        Some(p3::Scheme::Http) | None => "http".to_string(),
        Some(p3::Scheme::Https) => "https".to_string(),
        Some(p3::Scheme::Other(s)) => s,
    };
    // An absent or empty path is normalised to "/" so the origin-form fallback
    // below still parses. With an authority present, "http://host" and
    // "http://host/" yield the same `Uri::path()`.
    let path_with_query = req
        .get_path_with_query()
        .await
        .filter(|p| !p.is_empty())
        .unwrap_or_else(|| "/".into());

    let wasi_headers = req.get_headers().await;
    let header_list = wasi_headers.copy_all().await;
    // Release the headers handle before `req` is consumed below.
    // NOTE(review): presumably required because child resource handles must be
    // dropped before the parent request is consumed — confirm.
    drop(wasi_headers);

    let authority = match req.get_authority().await {
        Some(a) if !a.is_empty() => a,
        _ => header_list
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("host"))
            .and_then(|(_, value)| std::str::from_utf8(value).ok().map(str::to_owned))
            .unwrap_or_default(),
    };

    let uri_text = if authority.is_empty() {
        // Without an authority an absolute URI cannot be formed, so use
        // origin-form instead of failing the whole request.
        path_with_query
    } else {
        format!("{scheme_str}://{authority}{path_with_query}")
    };
    let uri: ::http::Uri = uri_text.parse().map_err(ServeError::InvalidUri)?;

    let mut builder = ::http::Request::builder().method(method).uri(uri);
    for (name, value) in header_list {
        builder = builder.header(name, value);
    }

    // Future handed to `consume_body`. Its writer is dropped unwritten in a
    // spawned task, so it resolves to the default `Ok(())` given here.
    // NOTE(review): relies on wit-bindgen delivering the default value when a
    // writer is dropped — confirm.
    let (trailers_writer, trailers_reader) =
        wit_future::new::<core::result::Result<(), p3::ErrorCode>>(|| Ok(()));
    crate::runtime::spawn(async move {
        drop(trailers_writer);
    });
    let (body_stream, _resp_trailers) = p3::Request::consume_body(req, trailers_reader).await;
    let body_bytes: Vec<u8> = collect_stream(body_stream).await;

    builder.body(body_bytes).map_err(ServeError::BuildRequest)
}

/// Converts an [`http::Response`] produced by the dispatcher into a wasi:http
/// p3 response.
///
/// Every header value is copied as its own `(name, bytes)` entry. A
/// `Body::Stream` body is fully collected into memory first. An empty body
/// (`Body::Empty` or zero-length `Body::Bytes`) is sent with no contents
/// stream. Non-empty bodies are written from a spawned task. Response
/// trailers are always absent.
///
/// # Errors
///
/// * [`ServeError::Headers`] if `p3::Fields::from_list` rejects the headers.
/// * [`ServeError::InvalidStatusCode`] if the host rejects the status code.
async fn http_response_to_p3(
    resp: ::http::Response<Body>,
) -> core::result::Result<p3::Response, ServeError> {
    let (head, body) = resp.into_parts();

    let mut entries: Vec<(String, Vec<u8>)> = Vec::new();
    for (name, value) in head.headers.iter() {
        entries.push((name.as_str().to_owned(), value.as_bytes().to_vec()));
    }
    let fields = p3::Fields::from_list(entries)
        .await
        .map_err(ServeError::Headers)?;

    let payload = match body {
        Body::Empty => None,
        Body::Bytes(bytes) => (!bytes.is_empty()).then_some(bytes),
        Body::Stream(reader) => Some(reader.collect().await),
    };

    let contents = payload.map(|bytes| {
        let (mut writer, reader) = wit_stream::new::<u8>();
        crate::runtime::spawn(async move {
            // Any unwritten leftover is ignored.
            let _leftover = writer.write_all(bytes).await;
            drop(writer);
        });
        reader
    });

    // Trailers future whose writer is dropped unwritten, so it resolves to the
    // default `Ok(None)`: no trailers.
    let (trailers_writer, trailers_reader) =
        wit_future::new::<core::result::Result<Option<p3::Trailers>, p3::ErrorCode>>(|| Ok(None));
    crate::runtime::spawn(async move {
        drop(trailers_writer);
    });

    let (response, _transmit) = p3::Response::new(fields, contents, trailers_reader).await;
    response
        .set_status_code(head.status.as_u16())
        .await
        .map_err(|_| ServeError::InvalidStatusCode)?;

    Ok(response)
}

/// Reads `reader` to completion and returns every byte it produced.
async fn collect_stream(reader: StreamReader<u8>) -> Vec<u8> {
    reader.collect().await
}

fn method_from_p3(m: p3::Method) -> ::http::Method {
    match m {
        p3::Method::Get => ::http::Method::GET,
        p3::Method::Head => ::http::Method::HEAD,
        p3::Method::Post => ::http::Method::POST,
        p3::Method::Put => ::http::Method::PUT,
        p3::Method::Delete => ::http::Method::DELETE,
        p3::Method::Connect => ::http::Method::CONNECT,
        p3::Method::Options => ::http::Method::OPTIONS,
        p3::Method::Trace => ::http::Method::TRACE,
        p3::Method::Patch => ::http::Method::PATCH,
        p3::Method::Other(s) => {
            ::http::Method::from_bytes(s.as_bytes()).unwrap_or(::http::Method::GET)
        }
    }
}