fn0-worker 0.3.33

Worker binary for the fn0 FaaS platform
use crate::env_yaml;
use crate::vault_client::VaultClient;
use async_singleflight::Group;
use fn0::cache::{Bundle, BundleCache, Error};
use fn0::execute::ClientState;
use fn0::measure_cpu_time::SystemClock;
use fn0::wasmtime::Engine;
use fn0::wasmtime::component::Linker;
use opendal::{ErrorKind as OpendalErrorKind, Operator};
use serde::Deserialize;
use std::collections::{HashMap, VecDeque};
use std::io::Read;
use std::sync::Arc;
use tokio::sync::Mutex;

/// Bundle kind declared by the `manifest.json` member of a bundle archive.
///
/// Deserialized as an internally tagged enum: the JSON object carries a
/// `"kind"` field whose value is the lowercased variant name
/// (`"wasm"` or `"wasmjs"`).
#[derive(Debug, Deserialize)]
#[serde(tag = "kind", rename_all = "lowercase")]
enum Manifest {
    /// Bundle that is served from the compiled component alone.
    Wasm,
    /// Bundle that additionally requires an `entry.js` script in the archive.
    WasmJs,
}

/// One cached bundle together with the bookkeeping needed for eviction.
struct LruEntry {
    /// Project the bundle belongs to.
    project_id: String,
    /// Code version the bundle was built from.
    code_version: u64,
    /// The built bundle, shared with callers through the `Arc`.
    bundle: Arc<Bundle>,
    /// Size charged against the cache budget for this entry.
    size_bytes: usize,
}

/// Mutable cache state; always accessed through the mutex in `S3BundleCache`.
struct Inner {
    /// Currently registered code version per project id.
    registry: HashMap<String, u64>,
    /// Maps a domain name to the project id registered for it.
    domain_to_project_id: HashMap<String, String>,
    /// Cached bundles, most recently used at the front, eviction from the back.
    lru: VecDeque<LruEntry>,
    /// Running total of `size_bytes` over the entries held in `lru`.
    current_bytes: usize,
}

/// Bundle cache that reads compiled bundles through an OpenDAL operator and
/// keeps recently used ones in memory.
///
/// Clones share the same state and singleflight group, since both live
/// behind an `Arc`.
#[derive(Clone)]
pub struct S3BundleCache {
    /// Engine handed to `fn0::build_service_pre` when building a bundle.
    engine: Engine,
    /// Component linker handed to `fn0::build_service_pre`.
    linker: Linker<ClientState<SystemClock>>,
    /// Storage operator the bundle archives are read from.
    operator: Operator,
    /// Vault client passed to `env_yaml::load` for bundles that ship an `env.yaml`.
    vault: Arc<VaultClient>,
    /// Shared mutable state: registry, domain map and LRU.
    inner: Arc<Mutex<Inner>>,
    /// Budget for the summed size of cached bundles; the most recent entry is
    /// kept even when it alone exceeds this value.
    cache_size_bytes: usize,
    /// Coalesces concurrent loads; keyed by project id.
    singleflight: Arc<Group<String, Arc<Bundle>, Error>>,
}

impl S3BundleCache {
    /// Creates a cache with no registered projects, no domains and no cached
    /// bundles.
    ///
    /// `cache_size_bytes` is the budget used when evicting cached bundles.
    pub fn new(
        engine: Engine,
        linker: Linker<ClientState<SystemClock>>,
        operator: Operator,
        vault: Arc<VaultClient>,
        cache_size_bytes: usize,
    ) -> Self {
        let empty_state = Inner {
            registry: HashMap::new(),
            domain_to_project_id: HashMap::new(),
            lru: VecDeque::new(),
            current_bytes: 0,
        };

        Self {
            engine,
            linker,
            operator,
            vault,
            inner: Arc::new(Mutex::new(empty_state)),
            cache_size_bytes,
            singleflight: Arc::new(Group::new()),
        }
    }

    #[tracing::instrument(skip_all, fields(project_id = %project_id, code_version))]
    pub async fn register(&self, project_id: &str, code_version: u64) {
        let mut inner = self.inner.lock().await;
        let prev = inner
            .registry
            .insert(project_id.to_string(), code_version);
        if prev != Some(code_version)
            && let Some(pos) = inner.lru.iter().position(|e| e.project_id == project_id)
        {
            let removed = inner.lru.remove(pos).unwrap();
            inner.current_bytes = inner.current_bytes.saturating_sub(removed.size_bytes);
        }
    }

    /// Removes `project_id` from the registry and drops its cached bundle, if
    /// one is present, releasing its size from the running total.
    pub async fn unregister(&self, project_id: &str) {
        let mut state = self.inner.lock().await;
        state.registry.remove(project_id);

        let cached_at = state
            .lru
            .iter()
            .position(|entry| entry.project_id == project_id);
        if let Some(dropped) = cached_at.and_then(|idx| state.lru.remove(idx)) {
            state.current_bytes = state.current_bytes.saturating_sub(dropped.size_bytes);
        }
    }

    /// Maps `domain` to `project_id`, replacing any mapping the domain already had.
    pub async fn register_domain(&self, domain: &str, project_id: &str) {
        let (domain, project_id) = (domain.to_owned(), project_id.to_owned());
        let mut state = self.inner.lock().await;
        state.domain_to_project_id.insert(domain, project_id);
    }

    /// Forgets the mapping for `domain`; a domain that was never registered is ignored.
    pub async fn unregister_domain(&self, domain: &str) {
        // The guard is a temporary and is released at the end of the statement.
        self.inner.lock().await.domain_to_project_id.remove(domain);
    }

    /// Returns an owned copy of the project id mapped to `domain`, or `None`
    /// when the domain has no mapping.
    pub async fn resolve_domain(&self, domain: &str) -> Option<String> {
        let state = self.inner.lock().await;
        let project_id = state.domain_to_project_id.get(domain)?;
        Some(project_id.clone())
    }

    #[tracing::instrument(skip_all, fields(project_id = %project_id))]
    async fn get_impl(&self, project_id: &str) -> Result<Arc<Bundle>, Error> {
        let code_version = {
            let mut inner = self.inner.lock().await;
            let Some(&code_version) = inner.registry.get(project_id) else {
                return Err(Error::NotFound);
            };
            if let Some(pos) = inner
                .lru
                .iter()
                .position(|e| e.project_id == project_id && e.code_version == code_version)
            {
                let entry = inner.lru.remove(pos).unwrap();
                let bundle = entry.bundle.clone();
                inner.lru.push_front(entry);
                return Ok(bundle);
            }
            code_version
        };

        let (bundle, size) = self.fetch_and_build(project_id, code_version).await?;

        let mut inner = self.inner.lock().await;
        if let Some(pos) = inner.lru.iter().position(|e| e.project_id == project_id) {
            let removed = inner.lru.remove(pos).unwrap();
            inner.current_bytes = inner.current_bytes.saturating_sub(removed.size_bytes);
        }
        inner.lru.push_front(LruEntry {
            project_id: project_id.to_string(),
            code_version,
            bundle: bundle.clone(),
            size_bytes: size,
        });
        inner.current_bytes += size;

        while inner.current_bytes > self.cache_size_bytes && inner.lru.len() > 1 {
            if let Some(evicted) = inner.lru.pop_back() {
                inner.current_bytes = inner.current_bytes.saturating_sub(evicted.size_bytes);
            } else {
                break;
            }
        }

        Ok(bundle)
    }

    #[tracing::instrument(skip_all, fields(project_id = %project_id, code_version))]
    async fn fetch_and_build(
        &self,
        project_id: &str,
        code_version: u64,
    ) -> Result<(Arc<Bundle>, usize), Error> {
        let key = format!(
            "compiled/{fn0_wasmtime_version}/{project_id}/{code_version}.tar.zst",
            fn0_wasmtime_version = fn0::FN0_WASMTIME_VERSION,
        );
        let compressed = match self.operator.read(&key).await {
            Ok(buf) => buf.to_vec(),
            Err(e) if e.kind() == OpendalErrorKind::NotFound => return Err(Error::NotFound),
            Err(e) => return Err(Error::Storage(anyhow::anyhow!(e))),
        };

        let tar_bytes = zstd::decode_all(compressed.as_slice())
            .map_err(|e| Error::Decode(anyhow::anyhow!(e)))?;

        let mut manifest: Option<Manifest> = None;
        let mut wasm_bytes: Option<Vec<u8>> = None;
        let mut js_bytes: Option<Vec<u8>> = None;
        let mut env_yaml_bytes: Option<Vec<u8>> = None;

        let mut archive = tar::Archive::new(tar_bytes.as_slice());
        let entries = archive
            .entries()
            .map_err(|e| Error::Decode(anyhow::anyhow!(e)))?;
        for entry in entries {
            let mut entry = entry.map_err(|e| Error::Decode(anyhow::anyhow!(e)))?;
            let path = entry
                .path()
                .map_err(|e| Error::Decode(anyhow::anyhow!(e)))?
                .to_path_buf();
            let path_str = path.to_string_lossy().to_string();

            let mut buf = Vec::new();
            entry
                .read_to_end(&mut buf)
                .map_err(|e| Error::Decode(anyhow::anyhow!(e)))?;

            match path_str.as_str() {
                "manifest.json" => {
                    manifest = Some(
                        serde_json::from_slice(&buf)
                            .map_err(|e| Error::Decode(anyhow::anyhow!(e)))?,
                    );
                }
                "wasm.cwasm.zst" => {
                    let decompressed = zstd::decode_all(buf.as_slice())
                        .map_err(|e| Error::Decode(anyhow::anyhow!(e)))?;
                    wasm_bytes = Some(decompressed);
                }
                "entry.js" => {
                    js_bytes = Some(buf);
                }
                "env.yaml" => {
                    env_yaml_bytes = Some(buf);
                }
                _ => {}
            }
        }

        let manifest =
            manifest.ok_or_else(|| Error::Decode(anyhow::anyhow!("missing manifest.json")))?;
        let cwasm =
            wasm_bytes.ok_or_else(|| Error::Decode(anyhow::anyhow!("missing wasm.cwasm.zst")))?;

        let service_pre =
            fn0::build_service_pre(&self.engine, &self.linker, &cwasm).map_err(Error::Compile)?;
        let cwasm_size = cwasm.len();

        let js = match manifest {
            Manifest::Wasm => None,
            Manifest::WasmJs => {
                let js = js_bytes.ok_or_else(|| {
                    Error::Decode(anyhow::anyhow!("wasmjs bundle missing entry.js"))
                })?;
                let js_str =
                    String::from_utf8(js).map_err(|e| Error::Decode(anyhow::anyhow!(e)))?;
                Some(js_str)
            }
        };
        let js_size = js.as_ref().map(|s| s.len()).unwrap_or(0);

        let env_vars = match env_yaml_bytes {
            Some(bytes) => env_yaml::load(&bytes, &self.vault)
                .await
                .map_err(|e| Error::Decode(anyhow::anyhow!("env.yaml load: {e}")))?,
            None => Vec::new(),
        };

        let size = cwasm_size + js_size;
        Ok((
            Arc::new(Bundle {
                service_pre,
                js,
                env_vars,
            }),
            size,
        ))
    }
}

impl BundleCache for S3BundleCache {
    async fn get(&self, project_id: &str) -> Result<Arc<Bundle>, Error> {
        let key = project_id.to_string();
        let provider = self.clone();
        let key_for_task = key.clone();
        let singleflight = provider.singleflight.clone();
        singleflight
            .work(&key, async move { provider.get_impl(&key_for_task).await })
            .await
            .map_err(|opt| opt.unwrap_or(Error::SingleflightLeaderFailed))
    }

    /// Drops the cached bundle for `project_id`, if any, and releases its size
    /// from the running total. The registry entry is left in place.
    async fn invalidate(&self, project_id: &str) {
        let mut state = self.inner.lock().await;
        let cached_at = state
            .lru
            .iter()
            .position(|entry| entry.project_id == project_id);
        if let Some(idx) = cached_at {
            let dropped = state.lru.remove(idx).unwrap();
            state.current_bytes = state.current_bytes.saturating_sub(dropped.size_bytes);
        }
    }
}