fn0 0.2.13

FaaS platform powered by wasmtime
use crate::Fn0;
use adapt_cache::AdaptCache;
use bytes::Bytes;
use hyper::HeaderMap;
use hyper::http;
use http_body_util::{BodyExt, Full, combinators::UnsyncBoxBody};
use ski::{FetchHandler, FetchHandlerFuture};
use std::string::FromUtf8Error;
use std::sync::{Arc, Mutex};

/// Fetch handler that routes `/__forte_hook/` requests to a forte backend
/// and records the `Set-Cookie` headers those backend responses carry.
pub(crate) struct ForteFetchHandler<J: AdaptCache<String, FromUtf8Error> + 'static> {
    /// Shared platform handle used to invoke the backend.
    fn0: Arc<Fn0<J>>,
    /// Identifier of the code passed to the backend on each hook call.
    code_id: String,
    /// Headers supplied at construction; merged into every hook request.
    original_headers: HeaderMap,
    /// `Set-Cookie` values gathered from hook responses, shared with the
    /// futures returned by `handle`.
    collected_cookies: Arc<Mutex<Vec<String>>>,
}

impl<J> ForteFetchHandler<J>
where
    J: AdaptCache<String, FromUtf8Error> + 'static,
{
    /// Builds a handler for `code_id` that starts with no collected cookies.
    ///
    /// `original_headers` is stored as-is and later merged into hook requests.
    pub(crate) fn new(fn0: Arc<Fn0<J>>, code_id: String, original_headers: HeaderMap) -> Self {
        let collected_cookies = Arc::new(Mutex::new(Vec::new()));
        Self {
            fn0,
            code_id,
            original_headers,
            collected_cookies,
        }
    }

    /// Returns a copy of every `Set-Cookie` value collected so far.
    ///
    /// # Panics
    ///
    /// Panics if the cookie mutex has been poisoned.
    pub(crate) fn get_collected_cookies(&self) -> Vec<String> {
        let cookies = self.collected_cookies.lock().unwrap();
        (*cookies).clone()
    }
}

impl<J> FetchHandler for ForteFetchHandler<J>
where
    J: AdaptCache<String, FromUtf8Error> + Send + Sync + 'static,
{
    fn handle(&self, req: crate::Request) -> FetchHandlerFuture {
        let path = req.uri().path().to_string();

        if !path.starts_with("/__forte_hook/") {
            return Box::pin(async { None });
        }

        let fn0 = self.fn0.clone();
        let code_id = self.code_id.clone();
        let original_headers = self.original_headers.clone();
        let collected_cookies = self.collected_cookies.clone();

        Box::pin(async move {
            let (mut parts, body) = req.into_parts();

            for (key, value) in &original_headers {
                if key == http::header::HOST {
                    continue;
                }
                if !parts.headers.contains_key(key) {
                    parts.headers.insert(key.clone(), value.clone());
                }
            }

            let path_and_query = parts
                .uri
                .path_and_query()
                .map(|pq| pq.as_str())
                .unwrap_or("/");

            let Some(host) = original_headers
                .get(http::header::HOST)
                .and_then(|v| v.to_str().ok())
            else {
                let body: UnsyncBoxBody<Bytes, anyhow::Error> =
                    Full::new(Bytes::from("Missing Host header in original request"))
                        .map_err(|e| anyhow::anyhow!("{e}"))
                        .boxed_unsync();
                return Some(
                    hyper::Response::builder()
                        .status(400)
                        .body(body)
                        .expect("failed to build response"),
                );
            };

            let new_uri = format!("http://{}{}", host, path_and_query);
            let Ok(uri) = new_uri.parse() else {
                let body: UnsyncBoxBody<Bytes, anyhow::Error> =
                    Full::new(Bytes::from("Invalid URI"))
                        .map_err(|e| anyhow::anyhow!("{e}"))
                        .boxed_unsync();
                return Some(
                    hyper::Response::builder()
                        .status(400)
                        .body(body)
                        .expect("failed to build response"),
                );
            };
            parts.uri = uri;

            let req = hyper::Request::from_parts(parts, body);
            let response = fn0.run_forte_backend(&code_id, "", req, None).await;

            match response {
                Ok(resp) => {
                    for value in resp.headers().get_all(http::header::SET_COOKIE) {
                        if let Ok(s) = value.to_str() {
                            collected_cookies.lock().unwrap().push(s.to_string());
                        }
                    }
                    Some(resp)
                }
                Err(e) => {
                    eprintln!("[ForteFetchHandler] Error calling hook: {:?}", e);
                    let body: UnsyncBoxBody<Bytes, anyhow::Error> =
                        Full::new(Bytes::from(format!("Hook error: {}", e)))
                            .map_err(|e| anyhow::anyhow!("{e}"))
                            .boxed_unsync();
                    Some(
                        hyper::Response::builder()
                            .status(500)
                            .header("content-type", "text/plain")
                            .body(body)
                            .unwrap(),
                    )
                }
            }
        })
    }
}