fn0 0.2.14

FaaS platform powered by wasmtime
use crate::execute::ClientState;
use crate::measure_cpu_time::{Clock, SystemClock, TimeTracker, measure_cpu_time};
use crate::{Request, Response, telemetry};
use anyhow::{Result, anyhow};
use bytes::Bytes;
use http_body_util::BodyExt;
use http_body_util::combinators::UnsyncBoxBody;
use hyper::http;
use std::cell::Cell;
use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use wasmtime::AsContextMut;
use wasmtime::component::Accessor;
use wasmtime_wasi::TrappableError;
use wasmtime_wasi_http::p3::Request as P3Request;
use wasmtime_wasi_http::p3::bindings::Service;
use wasmtime_wasi_http::p3::bindings::http::types::ErrorCode;
use wasmtime_wasi_http::p3::{RequestOptions, WasiHttpHooks, default_send_request};

pub async fn call_wasm_direct(req: Request) -> Result<Response> {
    let acc_ptr = ACCESSOR_PTR.with(|c| c.get());
    let svc_ptr = SERVICE_PTR.with(|c| c.get());
    let (Some(acc_ptr), Some(svc_ptr)) = (acc_ptr, svc_ptr) else {
        return Err(anyhow!("no accessor installed in current thread"));
    };
    let accessor: &Accessor<ClientState<SystemClock>> = unsafe { &*acc_ptr };
    let service: &Service = unsafe { &*svc_ptr };

    let (time_tracker, is_timeout, code_id) = accessor.with(|mut access| {
        let state = access.data_mut();
        (
            state.time_tracker.clone(),
            state.is_timeout.clone(),
            state.code_id.clone(),
        )
    });

    let req_http = req.map(|body| {
        body.map_err(|err| ErrorCode::InternalError(Some(err.to_string())))
            .boxed_unsync()
    });
    let (p3_req, req_io) = P3Request::from_http(req_http);
    call_service(
        accessor,
        service,
        p3_req,
        req_io,
        &code_id,
        time_tracker,
        &is_timeout,
    )
    .await
}

// Per-thread slots holding the accessor/service pair for the guest currently being served.
// `AccessorGuard::install` fills them and its `Drop` resets them to `None`. They are read by
// `call_wasm_direct` and by the self-invoke path of `SelfInvokeHooks::send_request`.
thread_local! {
    /// Raw pointer to the installed `Accessor`; `None` when no guard is active on this thread.
    static ACCESSOR_PTR: Cell<Option<*const Accessor<ClientState<SystemClock>>>> = const { Cell::new(None) };
    /// Raw pointer to the installed `Service` bindings; `None` when no guard is active.
    static SERVICE_PTR: Cell<Option<*const Service>> = const { Cell::new(None) };
}

tokio::task_local! {
    /// Host name that `SelfInvokeHooks::send_request` compares outgoing request URIs against
    /// to detect requests addressed back to this deployment. When it is not set for the
    /// current task, no request is treated as a self-invoke.
    /// NOTE(review): presumably populated from `extract_host` of the inbound request — confirm
    /// at the `SELF_HOST.scope(..)` call site.
    pub(crate) static SELF_HOST: String;
}

/// RAII guard that publishes an accessor/service pair in this thread's `ACCESSOR_PTR` and
/// `SERVICE_PTR` slots. `call_wasm_direct` and the self-invoke path of `SelfInvokeHooks`
/// look the pair up there.
///
/// Both slots are reset to `None` when the guard is dropped. Keep the guard bound to a named
/// variable (e.g. `let _guard = ...`) for as long as those readers may run. A discarded return
/// value, or one bound to `let _ = ...`, clears the slots immediately. `#[must_use]` turns the
/// discarded-value case into a compiler warning.
///
/// NOTE(review): the guard stores raw pointers that are not tied to a lifetime, and the slots
/// are thread-local. Soundness therefore relies on the guard not outliving
/// `accessor`/`service`, and on readers running on the installing thread — confirm at the
/// call sites.
#[must_use = "the accessor slots are cleared as soon as the guard is dropped"]
pub(crate) struct AccessorGuard;

impl AccessorGuard {
    /// Stores raw pointers to `accessor` and `service` in the current thread's slots,
    /// overwriting any previously installed pair. Returns the guard that clears them.
    ///
    /// The previous pair is not restored on drop, so nested guards leave the slots empty once
    /// the inner one is dropped.
    pub(crate) fn install(
        accessor: &Accessor<ClientState<SystemClock>>,
        service: &Service,
    ) -> Self {
        ACCESSOR_PTR.with(|c| c.set(Some(accessor as *const _)));
        SERVICE_PTR.with(|c| c.set(Some(service as *const _)));
        Self
    }
}

impl Drop for AccessorGuard {
    /// Empties both thread-local slots on the thread where the guard is dropped, in the same
    /// order `install` filled them.
    fn drop(&mut self) {
        ACCESSOR_PTR.set(None);
        SERVICE_PTR.set(None);
    }
}

pub(crate) fn extract_host(headers: &hyper::HeaderMap) -> Option<String> {
    let value = headers.get(hyper::header::HOST)?;
    let s = value.to_str().ok()?;
    Some(normalize_host(s))
}

/// Strips an optional `:port` suffix from a `Host` header value and lowercases the result.
///
/// IPv6 literals keep their brackets (`[::1]:8080` -> `[::1]`). This mirrors how
/// `http::Uri::host()` reports them, so the result compares correctly against a URI host.
/// Splitting on the first `:` would otherwise cut an IPv6 literal down to `[`.
fn normalize_host(host: &str) -> String {
    let without_port = if host.starts_with('[') {
        // IPv6 literal: keep everything up to and including the closing bracket. If the
        // bracket is missing, the value is malformed; keep it whole rather than guess.
        match host.find(']') {
            Some(end) => &host[..=end],
            None => host,
        }
    } else {
        // `split` always yields at least one item, so this never falls back.
        host.split(':').next().unwrap_or(host)
    };
    without_port.to_ascii_lowercase()
}

/// Reports whether `uri`'s host equals `self_host`, ignoring ASCII case.
///
/// A URI with no host (e.g. a relative reference) never matches.
fn matches_self(uri: &http::Uri, self_host: &str) -> bool {
    uri.host()
        .is_some_and(|candidate| candidate.eq_ignore_ascii_case(self_host))
}

/// Successful output of a `send_request` hook: the HTTP response plus an auxiliary I/O future.
/// The future is the `io` half returned by `default_send_request`, or an immediately-ready
/// future on the self-invoke path.
type HookResponse = (
    http::Response<UnsyncBoxBody<Bytes, ErrorCode>>,
    Box<dyn Future<Output = std::result::Result<(), ErrorCode>> + Send>,
);
/// Result type produced by the futures returned from `SelfInvokeHooks::send_request` and
/// `default_send`.
type HookResult = std::result::Result<HookResponse, TrappableError<ErrorCode>>;

/// Stateless `WasiHttpHooks` implementation. Outgoing requests whose URI host matches the
/// task's `SELF_HOST` are routed back into the running guest; everything else goes to the
/// default sender.
pub(crate) struct SelfInvokeHooks;

impl SelfInvokeHooks {
    /// Builds the hook set; there is no configuration to supply.
    pub(crate) fn new() -> Self {
        SelfInvokeHooks
    }
}

impl WasiHttpHooks for SelfInvokeHooks {
    /// Sends an outgoing guest request, either back into the running guest (self-invoke) or
    /// through [`default_send`].
    ///
    /// A request is a self-invoke only when `SELF_HOST` is set for the current task and
    /// `matches_self` reports that the request URI's host equals it (ASCII case-insensitive).
    ///
    /// On the self-invoke path, the returned future calls `service.handle` on the pair
    /// installed by `AccessorGuard`. It converts the guest's response with `into_http`, handing
    /// it the request-body I/O future, and returns an immediately-ready I/O future. Guest error
    /// codes are passed through unchanged. Wasmtime errors and `into_http` failures become
    /// `ErrorCode::InternalError` carrying their `Debug` text.
    fn send_request(
        &mut self,
        request: http::Request<UnsyncBoxBody<Bytes, ErrorCode>>,
        options: Option<RequestOptions>,
        // Not used on either path. NOTE(review): confirm the `WasiHttpHooks` contract allows
        // dropping it.
        _fut: Box<dyn Future<Output = std::result::Result<(), ErrorCode>> + Send>,
    ) -> Box<dyn Future<Output = HookResult> + Send> {
        // `try_with` fails outside a `SELF_HOST` scope; treat that as "not a self-invoke".
        let self_host = SELF_HOST.try_with(|h| h.clone()).ok();
        let is_self = self_host
            .as_deref()
            .map(|h| matches_self(request.uri(), h))
            .unwrap_or(false);

        if !is_self {
            return default_send(request, options);
        }

        Box::new(async move {
            // The slots are read when this future is first polled, not when `send_request` is
            // called. NOTE(review): this relies on that poll happening on the thread where
            // `AccessorGuard::install` ran, while the guard is still alive — confirm.
            let acc_ptr = ACCESSOR_PTR.with(|c| c.get());
            let svc_ptr = SERVICE_PTR.with(|c| c.get());
            let (Some(acc_ptr), Some(svc_ptr)) = (acc_ptr, svc_ptr) else {
                return Err(ErrorCode::InternalError(Some(
                    "self-invoke accessor slot empty".into(),
                ))
                .into());
            };
            // SAFETY: relies on the `AccessorGuard` invariant noted above (the pointees outlive
            // this future); nothing in this function enforces it.
            let accessor: &Accessor<ClientState<SystemClock>> = unsafe { &*acc_ptr };
            let service: &Service = unsafe { &*svc_ptr };

            // Re-enter the same guest. This nested call is not wrapped in `measure_cpu_time`
            // here. NOTE(review): there is no visible depth limit on repeated self-invokes.
            let (p3_req, req_io) = P3Request::from_http(request);
            let handle_result = service.handle(accessor, p3_req).await;

            match handle_result {
                Ok(Ok(resp)) => {
                    // `req_io` is handed to `into_http`, so the caller receives a no-op I/O
                    // future below.
                    let http_resp = accessor
                        .with(|mut access| resp.into_http(access.as_context_mut(), req_io))
                        .map_err(|e| {
                            TrappableError::from(ErrorCode::InternalError(Some(format!("{e:?}"))))
                        })?;
                    let io: Box<dyn Future<Output = std::result::Result<(), ErrorCode>> + Send> =
                        Box::new(async { Ok(()) });
                    Ok((http_resp, io))
                }
                // The guest itself returned an error code: forward it unchanged.
                Ok(Err(ec)) => Err(ec.into()),
                // Wasmtime-level failure (e.g. a trap): reported as an internal error.
                Err(e) => Err(ErrorCode::InternalError(Some(format!("{e:?}"))).into()),
            }
        })
    }
}

/// Forwards `request` to wasmtime-wasi-http's `default_send_request`. The response body is
/// boxed as `UnsyncBoxBody`, and the accompanying I/O future is boxed as a trait object to
/// fit [`HookResult`].
fn default_send(
    request: http::Request<UnsyncBoxBody<Bytes, ErrorCode>>,
    options: Option<RequestOptions>,
) -> Box<dyn Future<Output = HookResult> + Send> {
    Box::new(async move {
        let (response, body_io) = default_send_request(request, options).await?;
        let boxed_response = response.map(BodyExt::boxed_unsync);
        let boxed_io: Box<dyn Future<Output = std::result::Result<(), ErrorCode>> + Send> =
            Box::new(body_io);
        Ok((boxed_response, boxed_io))
    })
}

pub(crate) async fn call_service<C: Clock>(
    accessor: &Accessor<ClientState<C>>,
    service: &Service,
    p3_req: P3Request,
    req_io: impl Future<Output = std::result::Result<(), ErrorCode>> + Send + 'static,
    code_id: &str,
    time_tracker: TimeTracker<C>,
    is_timeout: &Arc<AtomicBool>,
) -> Result<Response> {
    let handle_fut = service.handle(accessor, p3_req);
    let handle_result = measure_cpu_time(time_tracker, handle_fut).await;

    match handle_result {
        Ok(Ok(resp)) => {
            let http_resp = accessor
                .with(|mut access| resp.into_http(access.as_context_mut(), req_io))
                .map_err(|error| {
                    telemetry::wasmtime_error("response_into_http", code_id, &format!("{error:?}"));
                    anyhow!("response into_http failed: {error:?}")
                })?;
            Ok(http_resp.map(|body| {
                body.map_err(|ec| anyhow!("error_code: {ec:?}"))
                    .boxed_unsync()
            }))
        }
        Ok(Err(ec)) => {
            telemetry::proxy_returns_error_code(code_id, &format!("{ec:?}"));
            Err(anyhow!("proxy returned error code: {ec:?}"))
        }
        Err(error) => Err(classify_wasm_error(error, code_id, is_timeout)),
    }
}

pub(crate) fn classify_wasm_error(
    error: wasmtime::Error,
    code_id: &str,
    is_timeout: &Arc<AtomicBool>,
) -> anyhow::Error {
    match error.downcast::<wasmtime::Trap>() {
        Ok(trap) => {
            telemetry::trapped(code_id, &format!("{trap:?}"));
            if is_timeout.load(Ordering::Relaxed) {
                anyhow!("CPU time limit exceeded (trapped: {trap:?})")
            } else {
                anyhow!("trapped: {trap:?}")
            }
        }
        Err(error) => {
            telemetry::canceled_unexpectedly(code_id, &format!("{error:?}"));
            if is_timeout.load(Ordering::Relaxed) {
                anyhow!("CPU time limit exceeded: {error:?}")
            } else {
                anyhow!("canceled unexpectedly: {error:?}")
            }
        }
    }
}