gestalt-sdk 0.0.1-alpha.12

Rust SDK scaffolding and generated protocol bindings for Gestalt executable providers
Documentation
use std::env;
#[cfg(unix)]
use std::future::Future;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;

#[cfg(unix)]
use tokio::net::UnixListener;
#[cfg(unix)]
use tokio::signal;
#[cfg(unix)]
use tokio::time::sleep;
#[cfg(unix)]
use tokio_stream::wrappers::UnixListenerStream;
#[cfg(unix)]
use tonic::transport::Server;

use crate::catalog::write_catalog;
use crate::env::{
    ENV_PROVIDER_NAME, ENV_PROVIDER_PARENT_PID, ENV_PROVIDER_SOCKET, ENV_WRITE_CATALOG,
};
use crate::error::{Error, Result};
#[cfg(unix)]
use crate::generated::v1::auth_provider_server::AuthProviderServer;
#[cfg(unix)]
use crate::generated::v1::cache_server::CacheServer;
#[cfg(unix)]
use crate::generated::v1::integration_provider_server::IntegrationProviderServer;
#[cfg(unix)]
use crate::generated::v1::provider_lifecycle_server::ProviderLifecycleServer;
#[cfg(unix)]
use crate::generated::v1::s3_server::S3Server;
#[cfg(unix)]
use crate::generated::v1::secrets_provider_server::SecretsProviderServer;
use crate::provider_server::ProviderServer;
use crate::{AuthProvider, CacheProvider, S3Provider, SecretsProvider};
use crate::{Provider, Router};
#[cfg(unix)]
use crate::{
    auth_server::AuthServer, cache_server::CacheRpcServer, runtime_server::RuntimeServer,
    secrets_server::SecretsServer,
};

fn build_runtime_and_block_on<F, Fut>(f: F) -> Result<()>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = Result<()>>,
{
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()
        .map_err(|error| Error::internal(error.to_string()))?;
    runtime.block_on(f())
}

/// Serves an integration provider with its `router`, blocking the calling
/// thread on a freshly built Tokio runtime until serving finishes.
pub fn run_provider<P: Provider>(provider: Arc<P>, router: Router<P>) -> Result<()> {
    let entry = move || serve_provider(provider, router);
    build_runtime_and_block_on(entry)
}

/// Serves an auth provider, blocking the calling thread on a freshly built
/// Tokio runtime until serving finishes.
pub fn run_auth_provider<P: AuthProvider>(provider: Arc<P>) -> Result<()> {
    let entry = move || serve_auth_provider(provider);
    build_runtime_and_block_on(entry)
}

/// Serves a cache provider, blocking the calling thread on a freshly built
/// Tokio runtime until serving finishes.
pub fn run_cache_provider<P: CacheProvider>(provider: Arc<P>) -> Result<()> {
    let entry = move || serve_cache_provider(provider);
    build_runtime_and_block_on(entry)
}

/// Serves a secrets provider, blocking the calling thread on a freshly built
/// Tokio runtime until serving finishes.
pub fn run_secrets_provider<P: SecretsProvider>(provider: Arc<P>) -> Result<()> {
    let entry = move || serve_secrets_provider(provider);
    build_runtime_and_block_on(entry)
}

/// Serves an S3 provider, blocking the calling thread on a freshly built
/// Tokio runtime until serving finishes.
pub fn run_s3_provider<P: S3Provider>(provider: Arc<P>) -> Result<()> {
    let entry = move || serve_s3_provider(provider);
    build_runtime_and_block_on(entry)
}

/// Writes `router`'s catalog, exactly as registered (no environment-based
/// renaming), to `path`.
pub fn write_catalog_path<P>(router: &Router<P>, path: impl AsRef<Path>) -> Result<()> {
    let catalog = router.catalog();
    write_catalog(catalog, path)
}

/// Writes the router's catalog if catalog export was requested via the
/// environment.
///
/// When `ENV_WRITE_CATALOG` is unset this returns `Ok(false)` and does
/// nothing. Otherwise the catalog is written to that path and `Ok(true)` is
/// returned. If `ENV_PROVIDER_NAME` is set to valid Unicode, the written
/// catalog carries that name.
///
/// # Errors
///
/// Propagates any failure from writing the catalog file.
pub fn maybe_write_catalog<P>(router: &Router<P>) -> Result<bool> {
    let target = match env::var_os(ENV_WRITE_CATALOG) {
        Some(raw) => raw,
        None => return Ok(false),
    };

    let catalog = router.catalog().clone();
    let catalog = match env::var(ENV_PROVIDER_NAME) {
        Ok(name) => catalog.with_name(name),
        Err(_) => catalog,
    };

    write_catalog(&catalog, PathBuf::from(target))?;
    Ok(true)
}

/// Serves an integration provider over a Unix socket.
///
/// If catalog export is requested through the environment (see
/// `maybe_write_catalog`), the catalog is written and the function returns
/// without binding a socket. Otherwise, the lifecycle and integration gRPC
/// services are served until shutdown, after which the provider is closed.
#[cfg(unix)]
pub async fn serve_provider<P>(provider: Arc<P>, router: Router<P>) -> Result<()>
where
    P: Provider,
{
    let wrote_catalog = maybe_write_catalog(&router)?;
    if wrote_catalog {
        return Ok(());
    }

    let integration = ProviderServer::new(Arc::clone(&provider), router);
    serve_unix_provider(
        provider,
        move |stream, shared| {
            let lifecycle =
                ProviderLifecycleServer::new(RuntimeServer::for_provider(Arc::clone(&shared)));
            Server::builder()
                .add_service(lifecycle)
                .add_service(IntegrationProviderServer::new(integration))
                .serve_with_incoming_shutdown(stream, shutdown_signal(parent_pid()))
        },
        |shared| async move { shared.close().await },
    )
    .await
}

/// Serves an auth provider over a Unix socket, exposing the lifecycle and
/// auth-provider gRPC services, and closes the provider once serving ends.
#[cfg(unix)]
pub async fn serve_auth_provider<P>(provider: Arc<P>) -> Result<()>
where
    P: AuthProvider,
{
    serve_unix_provider(
        provider,
        move |stream, shared| {
            let lifecycle =
                ProviderLifecycleServer::new(RuntimeServer::for_auth(Arc::clone(&shared)));
            let auth = AuthProviderServer::new(AuthServer::new(Arc::clone(&shared)));
            Server::builder()
                .add_service(lifecycle)
                .add_service(auth)
                .serve_with_incoming_shutdown(stream, shutdown_signal(parent_pid()))
        },
        |shared| async move { shared.close().await },
    )
    .await
}

/// Serves a cache provider over a Unix socket, exposing the lifecycle and
/// cache gRPC services, and closes the provider once serving ends.
#[cfg(unix)]
pub async fn serve_cache_provider<P>(provider: Arc<P>) -> Result<()>
where
    P: CacheProvider,
{
    serve_unix_provider(
        provider,
        move |stream, shared| {
            let lifecycle =
                ProviderLifecycleServer::new(RuntimeServer::for_cache(Arc::clone(&shared)));
            let cache = CacheServer::new(CacheRpcServer::new(Arc::clone(&shared)));
            Server::builder()
                .add_service(lifecycle)
                .add_service(cache)
                .serve_with_incoming_shutdown(stream, shutdown_signal(parent_pid()))
        },
        |shared| async move { shared.close().await },
    )
    .await
}

/// Serves a secrets provider over a Unix socket, exposing the lifecycle and
/// secrets-provider gRPC services, and closes the provider once serving ends.
#[cfg(unix)]
pub async fn serve_secrets_provider<P>(provider: Arc<P>) -> Result<()>
where
    P: SecretsProvider,
{
    serve_unix_provider(
        provider,
        move |stream, shared| {
            let lifecycle =
                ProviderLifecycleServer::new(RuntimeServer::for_secrets(Arc::clone(&shared)));
            let secrets = SecretsProviderServer::new(SecretsServer::new(Arc::clone(&shared)));
            Server::builder()
                .add_service(lifecycle)
                .add_service(secrets)
                .serve_with_incoming_shutdown(stream, shutdown_signal(parent_pid()))
        },
        |shared| async move { shared.close().await },
    )
    .await
}

/// Serves an S3 provider over a Unix socket, exposing the lifecycle and S3
/// gRPC services, and closes the provider once serving ends.
#[cfg(unix)]
pub async fn serve_s3_provider<P>(provider: Arc<P>) -> Result<()>
where
    P: S3Provider,
{
    serve_unix_provider(
        provider,
        move |stream, shared| {
            let lifecycle =
                ProviderLifecycleServer::new(RuntimeServer::for_s3(Arc::clone(&shared)));
            let s3 = S3Server::new(Arc::clone(&shared));
            Server::builder()
                .add_service(lifecycle)
                .add_service(s3)
                .serve_with_incoming_shutdown(stream, shutdown_signal(parent_pid()))
        },
        |shared| async move { shared.close().await },
    )
    .await
}

/// Fallback for platforms without Unix domain sockets.
///
/// Catalog export via the environment still works and returns `Ok(())`.
/// Otherwise an internal error is returned because the provider cannot be
/// served on this platform.
#[cfg(not(unix))]
pub async fn serve_provider<P>(_provider: Arc<P>, router: Router<P>) -> Result<()>
where
    P: Provider,
{
    if !maybe_write_catalog(&router)? {
        return Err(Error::internal(
            "unix sockets are unsupported on this platform",
        ));
    }
    Ok(())
}

/// Fallback for platforms without Unix domain sockets; always fails with an
/// internal error.
#[cfg(not(unix))]
pub async fn serve_auth_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: AuthProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}

/// Fallback for platforms without Unix domain sockets; always fails with an
/// internal error.
#[cfg(not(unix))]
pub async fn serve_cache_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: CacheProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}

/// Fallback for platforms without Unix domain sockets; always fails with an
/// internal error.
#[cfg(not(unix))]
pub async fn serve_secrets_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: SecretsProvider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}

/// Fallback for platforms without Unix domain sockets; always fails with an
/// internal error.
#[cfg(not(unix))]
pub async fn serve_s3_provider<P>(_provider: Arc<P>) -> Result<()>
where
    P: S3Provider,
{
    let reason = "unix sockets are unsupported on this platform";
    Err(Error::internal(reason))
}

/// Resolves when the server should shut down.
///
/// Shutdown is triggered by Ctrl-C (SIGINT). When `parent_pid` is `Some`,
/// shutdown is also triggered once this process's parent PID stops matching
/// it (see `watch_parent`); whichever happens first wins.
///
/// If the Ctrl-C handler cannot be registered, that trigger is treated as
/// never firing rather than as an immediate shutdown request.
#[cfg(unix)]
async fn shutdown_signal(parent_pid: Option<u32>) {
    let ctrl_c = async {
        if signal::ctrl_c().await.is_err() {
            // `ctrl_c` returns Err right away when handler registration
            // fails. Treating that as "Ctrl-C pressed" would stop the server
            // before it served a single request, so park this branch instead
            // and let the parent watchdog (if any) decide.
            std::future::pending::<()>().await;
        }
    };

    match parent_pid {
        Some(parent_pid) => {
            tokio::select! {
                _ = ctrl_c => {}
                _ = watch_parent(parent_pid) => {}
            }
        }
        None => ctrl_c.await,
    }
}

/// Binds the Unix socket named by `ENV_PROVIDER_SOCKET`, runs `serve` on the
/// incoming connections, then calls `close` on the provider and removes the
/// socket file.
///
/// `close` runs even if serving fails. A serve error takes precedence over a
/// close error in the returned result. Removing the socket afterwards is
/// best-effort and its error is ignored.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the socket variable is unset;
/// - a stale socket file cannot be removed;
/// - the socket's parent directory cannot be created;
/// - binding fails;
/// - serving or closing fails.
#[cfg(unix)]
async fn serve_unix_provider<P, F, S, C, CF>(provider: Arc<P>, serve: F, close: C) -> Result<()>
where
    P: Send + Sync,
    F: FnOnce(UnixListenerStream, Arc<P>) -> S,
    S: Future<Output = std::result::Result<(), tonic::transport::Error>>,
    C: FnOnce(Arc<P>) -> CF,
    CF: Future<Output = Result<()>>,
{
    let socket = env::var_os(ENV_PROVIDER_SOCKET)
        .map(PathBuf::from)
        .ok_or_else(|| Error::internal(format!("{ENV_PROVIDER_SOCKET} is required")))?;

    // Clear a stale socket left by a previous run. Unlinking unconditionally
    // (tolerating NotFound) avoids a check-then-act race. It also clears a
    // dangling symlink, which `Path::exists` reports as absent because it
    // follows symlinks, but which would still make `bind` fail.
    remove_socket(&socket)?;
    if let Some(parent) = socket.parent()
        && !parent.as_os_str().is_empty()
    {
        std::fs::create_dir_all(parent)?;
    }

    let listener = UnixListener::bind(&socket)?;
    let incoming = UnixListenerStream::new(listener);
    let serve_result = serve(incoming, Arc::clone(&provider))
        .await
        .map_err(Error::from);

    // Always close the provider, even when serving failed.
    let close_result = close(provider).await;
    // Best-effort cleanup: any leftover socket is removed on the next start.
    let _ = remove_socket(&socket);

    serve_result?;
    close_result
}

/// Reads the expected parent PID from `ENV_PROVIDER_PARENT_PID`.
///
/// Returns `None` if the variable is unset, is not valid Unicode, does not
/// parse as a `u32`, or is zero.
#[cfg(unix)]
fn parent_pid() -> Option<u32> {
    let raw = env::var(ENV_PROVIDER_PARENT_PID).ok()?;
    match raw.parse::<u32>() {
        Ok(0) | Err(_) => None,
        Ok(pid) => Some(pid),
    }
}

/// Completes once this process's parent PID no longer equals `parent_pid`.
///
/// The check runs immediately and is then repeated every 500 ms.
#[cfg(unix)]
async fn watch_parent(parent_pid: u32) {
    const POLL_INTERVAL: Duration = Duration::from_millis(500);
    while current_parent_pid() == parent_pid {
        sleep(POLL_INTERVAL).await;
    }
}

/// Returns the PID of this process's current parent.
///
/// Uses the safe standard-library wrapper around `getppid(2)` instead of an
/// `unsafe` call into `libc` with an `as` cast. `getppid` cannot fail, so no
/// error handling is needed.
#[cfg(unix)]
fn current_parent_pid() -> u32 {
    std::os::unix::process::parent_id()
}

/// Removes the file at `path`, treating an already-missing file as success.
///
/// # Errors
///
/// Returns any I/O error other than `NotFound` from `std::fs::remove_file`.
#[cfg(unix)]
fn remove_socket(path: &Path) -> std::io::Result<()> {
    std::fs::remove_file(path).or_else(|err| {
        if err.kind() == std::io::ErrorKind::NotFound {
            Ok(())
        } else {
            Err(err)
        }
    })
}