fn0 0.2.29

FaaS platform powered by wasmtime
use anyhow::Result;
use base64::Engine;
use bytes::Bytes;
use http_body_util::BodyExt;
use http_body_util::combinators::UnsyncBoxBody;
use hyper::http::uri::Scheme;
use oci_rust_sdk::auth::{RequestSigner, SimpleAuthProvider, SimpleAuthProviderRequiredFields};
use serde_json::Value;
use std::sync::Arc;
use wasmtime_wasi_http::p3::bindings::http::types::ErrorCode;

/// Crypto operation a guest may request through the vault hijack.
///
/// Derives `Debug`/`PartialEq`/`Eq` so the operation can be logged and
/// compared directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Op {
    /// Guest path `/decrypt`.
    Decrypt,
    /// Guest path `/encrypt`.
    Encrypt,
    /// Guest path `/generate-dek`.
    GenerateDek,
}

impl Op {
    /// Maps a guest-facing request path to its operation.
    ///
    /// Only the exact paths `/decrypt`, `/encrypt` and `/generate-dek` are
    /// recognised; anything else (including trailing slashes) yields `None`.
    fn from_path(path: &str) -> Option<Self> {
        match path {
            "/decrypt" => Some(Self::Decrypt),
            "/encrypt" => Some(Self::Encrypt),
            "/generate-dek" => Some(Self::GenerateDek),
            _ => None,
        }
    }

    /// Path of the matching OCI crypto API call, under the `20180608`
    /// version segment.
    fn oci_path(&self) -> &'static str {
        match self {
            Self::Decrypt => "/20180608/decrypt",
            Self::Encrypt => "/20180608/encrypt",
            Self::GenerateDek => "/20180608/generateDataEncryptionKey",
        }
    }
}

/// Configuration and signer for turning requests addressed to a placeholder
/// host into signed requests for an OCI Vault crypto endpoint.
#[derive(Clone)]
pub struct VaultHijack {
    /// Host name that `matches` compares incoming URIs against.
    pub placeholder_host: String,
    /// The only project id whose requests `build_signed_request` accepts.
    allowed_project_id: String,
    /// Authority (`host` or `host:port`) of the OCI crypto endpoint.
    crypto_host: String,
    /// Key OCID written as `keyId` into every forwarded request body.
    key_ocid: String,
    /// Shared request signer built from the OCI API credentials.
    signer: Arc<RequestSigner>,
}

impl VaultHijack {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        placeholder_host: String,
        crypto_endpoint: String,
        allowed_project_id: String,
        key_ocid: String,
        tenancy: String,
        user: String,
        fingerprint: String,
        private_key_pem: String,
    ) -> Result<Self> {
        let crypto_host = host_from_endpoint(&crypto_endpoint)?;

        let provider: Arc<SimpleAuthProvider> = Arc::new(
            SimpleAuthProvider::builder(SimpleAuthProviderRequiredFields {
                tenancy,
                user,
                fingerprint,
                private_key: private_key_pem,
            })
            .build(),
        );
        let signer = Arc::new(
            RequestSigner::new(provider as Arc<dyn oci_rust_sdk::auth::AuthProvider>)
                .map_err(|e| anyhow::anyhow!("vault signer init failed: {e:?}"))?,
        );

        Ok(Self {
            placeholder_host,
            crypto_host,
            allowed_project_id,
            key_ocid,
            signer,
        })
    }

    pub fn from_env() -> Result<Self> {
        let crypto_endpoint = std::env::var("FN0_VAULT_CRYPTO_ENDPOINT")
            .map_err(|_| anyhow::anyhow!("FN0_VAULT_CRYPTO_ENDPOINT is required"))?;
        let placeholder_host = std::env::var("FN0_VAULT_PLACEHOLDER_HOST")
            .unwrap_or_else(|_| "fn0-vault.fn0.dev".to_string());
        let allowed_project_id = std::env::var("FN0_VAULT_ALLOWED_SUBDOMAIN")
            .map_err(|_| anyhow::anyhow!("FN0_VAULT_ALLOWED_SUBDOMAIN is required"))?;
        let key_ocid = std::env::var("FN0_VAULT_KEY_OCID")
            .map_err(|_| anyhow::anyhow!("FN0_VAULT_KEY_OCID is required"))?;
        let tenancy = std::env::var("FN0_VAULT_OCI_TENANCY_ID")
            .map_err(|_| anyhow::anyhow!("FN0_VAULT_OCI_TENANCY_ID is required"))?;
        let user = std::env::var("FN0_VAULT_OCI_USER_ID")
            .map_err(|_| anyhow::anyhow!("FN0_VAULT_OCI_USER_ID is required"))?;
        let fingerprint = std::env::var("FN0_VAULT_OCI_FINGERPRINT")
            .map_err(|_| anyhow::anyhow!("FN0_VAULT_OCI_FINGERPRINT is required"))?;
        let private_key_b64 = std::env::var("FN0_VAULT_OCI_PRIVATE_KEY_BASE64")
            .map_err(|_| anyhow::anyhow!("FN0_VAULT_OCI_PRIVATE_KEY_BASE64 is required"))?;
        let private_key_pem = base64::engine::general_purpose::STANDARD
            .decode(private_key_b64.as_bytes())
            .map_err(|e| anyhow::anyhow!("vault private key base64 decode: {e}"))
            .and_then(|b| {
                String::from_utf8(b).map_err(|e| anyhow::anyhow!("vault private key utf8: {e}"))
            })?;

        Self::new(
            placeholder_host,
            crypto_endpoint,
            allowed_project_id,
            key_ocid,
            tenancy,
            user,
            fingerprint,
            private_key_pem,
        )
    }

    pub fn placeholder_url(&self) -> String {
        format!("http://{}", self.placeholder_host)
    }

    pub(crate) fn matches(&self, uri: &hyper::Uri) -> bool {
        uri.host() == Some(self.placeholder_host.as_str())
    }

    pub(crate) fn build_signed_request(
        &self,
        project_id: &str,
        method: &str,
        path: &str,
        body_bytes: &[u8],
    ) -> Result<hyper::Request<UnsyncBoxBody<Bytes, ErrorCode>>, ErrorCode> {
        if project_id != self.allowed_project_id {
            return Err(ErrorCode::HttpRequestDenied);
        }
        if method != hyper::Method::POST.as_str() {
            return Err(ErrorCode::HttpRequestMethodInvalid);
        }
        let op = Op::from_path(path).ok_or(ErrorCode::HttpRequestDenied)?;

        let oci_body = build_oci_body(op, &self.key_ocid, body_bytes)
            .map_err(|e| ErrorCode::InternalError(Some(format!("vault body build: {e}"))))?;

        let oci_path = op.oci_path();
        let url_str = format!("https://{}{}", self.crypto_host, oci_path);
        let url = url::Url::parse(&url_str)
            .map_err(|e| ErrorCode::InternalError(Some(format!("vault url parse: {e}"))))?;

        let mut headers = reqwest::header::HeaderMap::new();
        self.signer
            .sign_request(method, &url, &mut headers, Some(&oci_body))
            .map_err(|e| ErrorCode::InternalError(Some(format!("vault sign: {e:?}"))))?;

        let uri = hyper::Uri::builder()
            .scheme(Scheme::HTTPS)
            .authority(self.crypto_host.as_str())
            .path_and_query(oci_path)
            .build()
            .map_err(|_| ErrorCode::HttpRequestUriInvalid)?;

        let mut builder = hyper::Request::builder().method(method).uri(uri);
        for (name, value) in headers.iter() {
            builder = builder.header(name.as_str(), value);
        }

        let body: UnsyncBoxBody<Bytes, ErrorCode> =
            http_body_util::Full::new(Bytes::from(oci_body))
                .map_err(|never: std::convert::Infallible| match never {})
                .boxed_unsync();

        builder
            .body(body)
            .map_err(|e| ErrorCode::InternalError(Some(format!("vault request build: {e}"))))
    }
}

/// Produces the JSON request body sent to the OCI crypto API for `op`.
///
/// The guest body must be a JSON object; an empty body is treated as `{}`.
/// `keyId` is always overwritten with `key_ocid`. For [`Op::GenerateDek`],
/// `keyShape` (AES, length 32) and `includePlaintextKey` (`false`) are filled
/// in only when the guest did not provide them.
///
/// # Errors
///
/// Fails if the body is not valid JSON, is not a JSON object, or cannot be
/// re-serialised.
fn build_oci_body(op: Op, key_ocid: &str, wasm_body: &[u8]) -> Result<Vec<u8>> {
    let mut payload: Value = match wasm_body {
        [] => Value::Object(serde_json::Map::new()),
        bytes => serde_json::from_slice(bytes)
            .map_err(|e| anyhow::anyhow!("vault body parse: {e}"))?,
    };

    let Some(fields) = payload.as_object_mut() else {
        return Err(anyhow::anyhow!("vault body must be a JSON object"));
    };

    // The configured key always wins over any `keyId` the guest sent.
    fields.insert("keyId".to_string(), Value::String(key_ocid.to_owned()));

    if matches!(op, Op::GenerateDek) {
        fields
            .entry("keyShape")
            .or_insert_with(|| serde_json::json!({ "algorithm": "AES", "length": 32 }));
        fields
            .entry("includePlaintextKey")
            .or_insert(Value::Bool(false));
    }

    let encoded = serde_json::to_vec(&payload)?;
    Ok(encoded)
}

/// Extracts the authority (`host`, or `host:port` when a non-default port is
/// given) from an endpoint URL such as `https://example.com:8443/ignored`.
///
/// # Errors
///
/// Fails if `endpoint` is not a valid URL or has no host component.
fn host_from_endpoint(endpoint: &str) -> Result<String> {
    let parsed = match url::Url::parse(endpoint) {
        Ok(parsed) => parsed,
        Err(e) => return Err(anyhow::anyhow!("vault endpoint parse: {e}")),
    };
    let Some(host) = parsed.host_str() else {
        return Err(anyhow::anyhow!("vault endpoint missing host: {endpoint}"));
    };
    // `Url::port` is `None` when the port is the scheme's default.
    let authority = match parsed.port() {
        Some(port) => format!("{host}:{port}"),
        None => host.to_owned(),
    };
    Ok(authority)
}