fn0-worker 0.1.0

Worker binary for the fn0 FaaS platform
use crate::WorkerFn0;
use crate::bundle::BundleFetcher;
use color_eyre::eyre::{Result, eyre};
use host_hq_protocol::{HostToHq, HqToHostDatagram, HqToHostReliable};
use quinn::Endpoint;
use rcgen::{CertificateParams, KeyPair};
use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

use super::read_pem_env;

pub async fn run_quic_server(
    port: u16,
    deployment_id: Arc<AtomicU64>,
    instance_count: Arc<AtomicU64>,
    graceful_shutdown: Arc<AtomicBool>,
    _fn0: Arc<WorkerFn0>,
    bundle_fetcher: Arc<BundleFetcher>,
) -> Result<()> {
    let ca_cert_pem = match read_pem_env("CA_CERT_PEM") {
        Some(v) => v,
        None => {
            tracing::info!("CA_CERT_PEM not set, QUIC server disabled");
            std::future::pending::<()>().await;
            return Ok(());
        }
    };
    let ca_key_pem = read_pem_env("CA_KEY_PEM")
        .ok_or_else(|| eyre!("CA_KEY_PEM env var not set"))?;

    let ca_key_pair = KeyPair::from_pem(&ca_key_pem)?;
    let ca_params = CertificateParams::from_ca_cert_pem(&ca_cert_pem)?;
    let ca_cert = ca_params.self_signed(&ca_key_pair)?;

    let mut worker_params = CertificateParams::new(vec!["host.fn0".to_string()])?;
    worker_params.is_ca = rcgen::IsCa::NoCa;
    let worker_key_pair = KeyPair::generate()?;
    let worker_cert = worker_params.signed_by(&worker_key_pair, &ca_cert, &ca_key_pair)?;

    let cert_der = CertificateDer::from(worker_cert.der().to_vec());
    let key_der = PrivatePkcs8KeyDer::from(worker_key_pair.serialize_der());

    let server_config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(vec![cert_der], key_der.into())?;

    let quinn_config = quinn::ServerConfig::with_crypto(Arc::new(
        quinn::crypto::rustls::QuicServerConfig::try_from(server_config)?,
    ));

    let addr = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 0], port));
    let endpoint = Endpoint::server(quinn_config, addr)?;
    tracing::info!(%addr, "QUIC server listening");

    while let Some(incoming) = endpoint.accept().await {
        let deployment_id = deployment_id.clone();
        let instance_count = instance_count.clone();
        let graceful_shutdown = graceful_shutdown.clone();
        let bundle_fetcher = bundle_fetcher.clone();

        tokio::spawn(async move {
            match incoming.await {
                Ok(connection) => {
                    tracing::info!(remote = %connection.remote_address(), "HQ connected");
                    handle_connection(connection, deployment_id, instance_count, graceful_shutdown, bundle_fetcher).await;
                }
                Err(err) => {
                    tracing::warn!(%err, "Failed to accept QUIC connection");
                }
            }
        });
    }

    Ok(())
}

async fn handle_connection(
    connection: quinn::Connection,
    deployment_id: Arc<AtomicU64>,
    instance_count: Arc<AtomicU64>,
    graceful_shutdown: Arc<AtomicBool>,
    bundle_fetcher: Arc<BundleFetcher>,
) {
    let datagram_handle = tokio::spawn({
        let connection = connection.clone();
        let deployment_id = deployment_id.clone();
        let instance_count = instance_count.clone();
        async move {
            loop {
                match connection.read_datagram().await {
                    Ok(bytes) => {
                        let Ok(msg) = HqToHostDatagram::from_bytes(bytes) else {
                            tracing::warn!("Failed to parse HQ datagram");
                            continue;
                        };

                        match msg {
                            HqToHostDatagram::AdvertiseLatestDeploymentId { .. } => {
                                let current = deployment_id.load(Ordering::Relaxed);
                                let timestamp = SystemTime::now()
                                    .duration_since(UNIX_EPOCH)
                                    .unwrap()
                                    .as_secs();

                                let pong = HostToHq::NotifyHostStatus {
                                    timestamp,
                                    deployment_id: current,
                                    instances: instance_count.load(Ordering::Relaxed),
                                };
                                if let Ok(bytes) = pong.to_bytes() {
                                    let _ = connection.send_datagram(bytes);
                                }
                            }
                        }
                    }
                    Err(err) => {
                        tracing::warn!(%err, "Datagram read error, HQ disconnected");
                        break;
                    }
                }
            }
        }
    });

    let reliable_handle = tokio::spawn({
        let connection = connection.clone();
        let deployment_id = deployment_id.clone();
        let graceful_shutdown = graceful_shutdown.clone();
        let bundle_fetcher = bundle_fetcher.clone();
        async move {
            loop {
                match connection.accept_uni().await {
                    Ok(mut recv) => {
                        let data = match recv.read_to_end(64 * 1024).await {
                            Ok(data) => data,
                            Err(err) => {
                                tracing::warn!(%err, "Failed to read reliable message");
                                continue;
                            }
                        };

                        let Ok(msg) = HqToHostReliable::from_bytes(data.into()) else {
                            tracing::warn!("Failed to parse HQ reliable message");
                            continue;
                        };

                        match msg {
                            HqToHostReliable::DeploymentUpdates {
                                deployment_id: new_deployment_id,
                                codes,
                            } => {
                                tracing::info!(
                                    new_deployment_id,
                                    codes = codes.len(),
                                    "Received deployment update"
                                );

                                for code in &codes {
                                    match code {
                                        host_hq_protocol::CodeDeployment::Deploy {
                                            subdomain,
                                            code_id,
                                            code_version,
                                        } => {
                                            if let Err(err) = bundle_fetcher
                                                .fetch_and_register(subdomain, *code_id, *code_version)
                                                .await
                                            {
                                                tracing::error!(%err, %subdomain, "Failed to fetch bundle");
                                            }
                                        }
                                        host_hq_protocol::CodeDeployment::Undeploy { subdomain } => {
                                            bundle_fetcher.unregister(subdomain).await;
                                        }
                                    }
                                }

                                deployment_id.store(new_deployment_id, Ordering::Relaxed);
                            }
                            HqToHostReliable::GracefulShutdown => {
                                tracing::info!("Received graceful shutdown from HQ");
                                graceful_shutdown.store(true, Ordering::Relaxed);
                            }
                        }
                    }
                    Err(err) => {
                        tracing::warn!(%err, "Uni stream accept error, HQ disconnected");
                        break;
                    }
                }
            }
        }
    });

    tokio::select! {
        _ = datagram_handle => {},
        _ = reliable_handle => {},
    }
}