coil-runtime 0.1.1

HTTP runtime and request handling for the Coil framework.
Documentation
use std::sync::OnceLock;

use coil_data::{DataRuntime, PostgresDataClient};
use sqlx::Row;

use super::*;

/// Metadata audit log and customer-managed-asset store kept in a shared Postgres schema.
///
/// The Postgres client and the schema/table/index DDL are both created lazily on
/// first use rather than in `open`.
#[derive(Debug, Clone)]
pub(super) struct SharedMetadataAuditStore {
    // Data runtime configuration; used to lazily create the Postgres client.
    runtime: DataRuntime,
    // Lazily created Postgres client. A creation error is stored as its message
    // and handed back on every later call.
    client: OnceLock<Result<PostgresDataClient, String>>,
    // Postgres schema holding `metadata_audit_entries` and
    // `customer_managed_assets` (copied from `runtime.schema` in `open`).
    schema: String,
    // Outcome of the one-time schema/table/index provisioning done by
    // `ensure_initialized`.
    initialized: OnceLock<Result<(), String>>,
}

impl SharedMetadataAuditStore {
    pub(super) fn open(runtime: DataRuntime) -> Self {
        let schema = runtime.schema.clone();
        Self {
            runtime,
            client: OnceLock::new(),
            schema,
            initialized: OnceLock::new(),
        }
    }

    pub(super) fn location_label(&self) -> String {
        format!("shared-postgres:{}.metadata_audit_entries", self.schema)
    }

    pub(super) fn insert(&self, record: &MetadataAuditRecord) -> Result<(), String> {
        self.ensure_initialized()?;
        let client = self.client()?.clone();
        let table = self.qualified_table();
        let record = record.clone();
        run_blocking(async move {
            sqlx::query(&format!(
                "INSERT INTO {} (recorded_at_unix_seconds, app_id, trace_id, request_id, principal_kind, principal_id, kind) VALUES ($1, $2, $3, $4, $5, $6, $7)",
                table
            ))
            .bind(record.recorded_at_unix_seconds)
            .bind(&record.app_id)
            .bind(&record.trace_id)
            .bind(&record.request_id)
            .bind(&record.principal_kind)
            .bind(&record.principal_id)
            .bind(&record.kind)
            .execute(&client.pool)
            .await
            .map_err(|error| format!("failed to write shared metadata audit entry: {error}"))?;
            Ok(())
        })
    }

    pub(super) fn count(&self) -> Result<usize, String> {
        self.ensure_initialized()?;
        let client = self.client()?.clone();
        let table = self.qualified_table();
        run_blocking(async move {
            let count: i64 = sqlx::query_scalar(&format!("SELECT COUNT(*) FROM {}", table))
                .fetch_one(&client.pool)
                .await
                .map_err(|error| {
                    format!("failed to count shared metadata audit entries: {error}")
                })?;
            usize::try_from(count)
                .map_err(|_| "shared metadata audit entry count overflowed usize".to_string())
        })
    }

    pub(super) fn recent(&self, limit: usize) -> Result<Vec<MetadataAuditRecord>, String> {
        if limit == 0 {
            return Ok(Vec::new());
        }

        self.ensure_initialized()?;
        let client = self.client()?.clone();
        let table = self.qualified_table();
        run_blocking(async move {
            let rows = sqlx::query(&format!(
                "SELECT id, recorded_at_unix_seconds, app_id, trace_id, request_id, principal_kind, principal_id, kind FROM {} ORDER BY recorded_at_unix_seconds DESC, id DESC LIMIT $1",
                table
            ))
            .bind(limit as i64)
            .fetch_all(&client.pool)
            .await
            .map_err(|error| format!("failed to query shared metadata audit entries: {error}"))?;

            let mut records = rows
                .into_iter()
                .map(|row| {
                    Ok(MetadataAuditRecord {
                        id: row.try_get(0).map_err(|error| {
                            format!("failed to decode shared metadata audit entry id: {error}")
                        })?,
                        recorded_at_unix_seconds: row.try_get(1).map_err(|error| {
                            format!("failed to decode shared metadata audit timestamp: {error}")
                        })?,
                        app_id: row.try_get(2).map_err(|error| {
                            format!("failed to decode shared metadata audit app id: {error}")
                        })?,
                        trace_id: row.try_get(3).map_err(|error| {
                            format!("failed to decode shared metadata audit trace id: {error}")
                        })?,
                        request_id: row.try_get(4).map_err(|error| {
                            format!("failed to decode shared metadata audit request id: {error}")
                        })?,
                        principal_kind: row.try_get(5).map_err(|error| {
                            format!(
                                "failed to decode shared metadata audit principal kind: {error}"
                            )
                        })?,
                        principal_id: row.try_get(6).map_err(|error| {
                            format!("failed to decode shared metadata audit principal id: {error}")
                        })?,
                        kind: row.try_get(7).map_err(|error| {
                            format!("failed to decode shared metadata audit kind: {error}")
                        })?,
                    })
                })
                .collect::<Result<Vec<_>, String>>()?;
            records.reverse();
            Ok(records)
        })
    }

    pub(super) fn upsert_customer_managed_asset(
        &self,
        logical_path: &str,
        record_json: &str,
        updated_at_unix_seconds: i64,
    ) -> Result<(), String> {
        self.ensure_initialized()?;
        let client = self.client()?.clone();
        let table = self.qualified_customer_managed_assets_table();
        let logical_path = logical_path.to_string();
        let record_json = record_json.to_string();
        run_blocking(async move {
            sqlx::query(&format!(
                "INSERT INTO {} (logical_path, record_json, updated_at_unix_seconds) VALUES ($1, $2, $3)
                 ON CONFLICT (logical_path) DO UPDATE SET
                    record_json = EXCLUDED.record_json,
                    updated_at_unix_seconds = EXCLUDED.updated_at_unix_seconds",
                table
            ))
            .bind(&logical_path)
            .bind(&record_json)
            .bind(updated_at_unix_seconds)
            .execute(&client.pool)
            .await
            .map_err(|error| {
                format!("failed to write shared customer managed asset `{logical_path}`: {error}")
            })?;
            Ok(())
        })
    }

    pub(super) fn customer_managed_asset(
        &self,
        logical_path: &str,
    ) -> Result<Option<String>, String> {
        self.ensure_initialized()?;
        let client = self.client()?.clone();
        let table = self.qualified_customer_managed_assets_table();
        let logical_path = logical_path.to_string();
        run_blocking(async move {
            let row = sqlx::query(&format!(
                "SELECT record_json FROM {} WHERE logical_path = $1",
                table
            ))
            .bind(&logical_path)
            .fetch_optional(&client.pool)
            .await
            .map_err(|error| {
                format!("failed to query shared customer managed asset `{logical_path}`: {error}")
            })?;
            match row {
                Some(row) => row.try_get(0).map(Some).map_err(|error| {
                    format!(
                        "failed to decode shared customer managed asset `{logical_path}`: {error}"
                    )
                }),
                None => Ok(None),
            }
        })
    }

    fn client(&self) -> Result<&PostgresDataClient, String> {
        self.client
            .get_or_init(|| {
                self.runtime
                    .connect_lazy_postgres()
                    .map_err(|error| error.to_string())
            })
            .as_ref()
            .map_err(|error| error.clone())
    }

    fn ensure_initialized(&self) -> Result<(), String> {
        let schema_ident = quote_identifier(&self.schema);
        self.initialized
            .get_or_init(|| {
                let client = self.client()?.clone();
                run_blocking(async move {
                    sqlx::query(&format!("CREATE SCHEMA IF NOT EXISTS {schema_ident}"))
                        .execute(&client.pool)
                        .await
                        .map_err(|error| format!("failed to initialize shared metadata schema: {error}"))?;

                    sqlx::query(&format!(
                        "CREATE TABLE IF NOT EXISTS {schema_ident}.metadata_audit_entries (
                            id BIGSERIAL PRIMARY KEY,
                            recorded_at_unix_seconds BIGINT NOT NULL,
                            app_id TEXT NOT NULL,
                            trace_id TEXT NOT NULL,
                            request_id TEXT,
                            principal_kind TEXT NOT NULL,
                            principal_id TEXT,
                            kind TEXT NOT NULL
                        )"
                    ))
                    .execute(&client.pool)
                    .await
                    .map_err(|error| format!("failed to initialize shared metadata audit table: {error}"))?;

                    sqlx::query(&format!(
                        "CREATE INDEX IF NOT EXISTS metadata_audit_entries_recent
                            ON {schema_ident}.metadata_audit_entries (recorded_at_unix_seconds DESC, id DESC)"
                    ))
                    .execute(&client.pool)
                    .await
                    .map_err(|error| format!("failed to initialize shared metadata audit index: {error}"))?;

                    sqlx::query(&format!(
                        "CREATE TABLE IF NOT EXISTS {schema_ident}.customer_managed_assets (
                            logical_path TEXT PRIMARY KEY,
                            record_json TEXT NOT NULL,
                            updated_at_unix_seconds BIGINT NOT NULL
                        )"
                    ))
                    .execute(&client.pool)
                    .await
                    .map_err(|error| format!("failed to initialize shared customer managed assets table: {error}"))?;

                    sqlx::query(&format!(
                        "CREATE INDEX IF NOT EXISTS customer_managed_assets_recent
                            ON {schema_ident}.customer_managed_assets (updated_at_unix_seconds DESC, logical_path DESC)"
                    ))
                    .execute(&client.pool)
                    .await
                    .map_err(|error| format!("failed to initialize shared customer managed assets index: {error}"))?;

                    Ok(())
                })
            })
            .clone()
    }

    fn qualified_table(&self) -> String {
        format!(
            "{}.{}",
            quote_identifier(&self.schema),
            quote_identifier("metadata_audit_entries")
        )
    }

    fn qualified_customer_managed_assets_table(&self) -> String {
        format!(
            "{}.{}",
            quote_identifier(&self.schema),
            quote_identifier("customer_managed_assets")
        )
    }
}

/// Quotes `identifier` as a Postgres delimited identifier.
///
/// The text is wrapped in double quotes, and every embedded `"` is doubled so
/// it cannot terminate the identifier early.
fn quote_identifier(identifier: &str) -> String {
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            // Escape by doubling the quote character.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}

/// Drives `future` to completion from synchronous code, whatever the caller's runtime context.
///
/// - No Tokio runtime on this thread: a temporary current-thread runtime is built here.
/// - Inside a multi-thread runtime: the worker is handed off with `block_in_place`
///   and the future runs on the caller's runtime.
/// - Inside any other flavor (e.g. current-thread, where `block_in_place` is not
///   allowed): the future runs on a separate thread with its own runtime.
fn run_blocking<T, F>(future: F) -> Result<T, String>
where
    T: Send + 'static,
    F: std::future::Future<Output = Result<T, String>> + Send + 'static,
{
    use tokio::runtime::{Handle, RuntimeFlavor};

    let Ok(handle) = Handle::try_current() else {
        return run_future_on_ephemeral_runtime(future);
    };

    if matches!(handle.runtime_flavor(), RuntimeFlavor::MultiThread) {
        tokio::task::block_in_place(|| handle.block_on(future))
    } else {
        run_future_on_dedicated_runtime(future)
    }
}

/// Runs `future` on a freshly spawned OS thread with its own current-thread runtime.
///
/// Used when the caller is already inside a runtime that cannot be blocked in
/// place. A runtime build failure is returned as its message. A panic on the
/// worker thread becomes an error rather than propagating.
///
/// NOTE(review): the temporary runtime is dropped when the thread exits. If the
/// sqlx pool opens connections while this runtime drives the future, confirm
/// those pooled connections remain usable afterwards.
fn run_future_on_dedicated_runtime<T, F>(future: F) -> Result<T, String>
where
    T: Send + 'static,
    F: std::future::Future<Output = Result<T, String>> + Send + 'static,
{
    let worker = std::thread::spawn(move || -> Result<T, String> {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|error| error.to_string())?;
        runtime.block_on(future)
    });

    worker
        .join()
        .unwrap_or_else(|_| Err("shared metadata worker thread panicked".to_string()))
}

/// Runs `future` on a throwaway current-thread runtime built on the calling thread.
///
/// Only valid when no Tokio runtime is active on this thread. A runtime build
/// failure is returned as its message.
fn run_future_on_ephemeral_runtime<T, F>(future: F) -> Result<T, String>
where
    T: Send + 'static,
    F: std::future::Future<Output = Result<T, String>> + Send + 'static,
{
    match tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
    {
        Ok(runtime) => runtime.block_on(future),
        Err(error) => Err(error.to_string()),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a Postgres `DataRuntime` for `schema` with no connection secret.
    fn postgres_runtime(schema: &str) -> DataRuntime {
        DataRuntime {
            driver: coil_config::DatabaseDriver::Postgres,
            connection_secret_ref: None,
            connection_secret: None,
            schema: schema.to_string(),
            migrations_table: "migrations".to_string(),
            pool: coil_data::ConnectionPoolProfile {
                min_connections: 1,
                max_connections: 4,
                statement_timeout: Duration::from_secs(30),
            },
        }
    }

    #[test]
    fn shared_metadata_backend_labels_the_selected_backend_and_location() {
        let backend = SharedMetadataAuditStore::open(postgres_runtime("public"));

        assert_eq!(
            backend.location_label(),
            "shared-postgres:public.metadata_audit_entries"
        );
    }

    #[test]
    fn shared_metadata_recent_with_zero_limit_returns_before_touching_the_store() {
        // `recent(0)` returns before initialization, so no database is needed.
        let backend = SharedMetadataAuditStore::open(postgres_runtime("public"));

        assert!(backend
            .recent(0)
            .expect("a zero limit should not fail")
            .is_empty());
    }

    #[test]
    fn quote_identifier_wraps_and_escapes_double_quotes() {
        assert_eq!(quote_identifier("public"), "\"public\"");
        assert_eq!(quote_identifier("we\"ird"), "\"we\"\"ird\"");
    }

    #[test]
    fn shared_metadata_run_blocking_works_inside_current_thread_runtime() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        let value = runtime.block_on(async { run_blocking(async { Ok::<_, String>(7usize) }) });

        assert_eq!(value.unwrap(), 7);
    }

    #[test]
    fn shared_metadata_run_blocking_works_inside_multi_thread_runtime() {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .worker_threads(2)
            .enable_all()
            .build()
            .unwrap();

        // Spawn so the call happens on a worker thread, exercising `block_in_place`.
        let value = runtime.block_on(async {
            tokio::spawn(async { run_blocking(async { Ok::<_, String>(11usize) }) })
                .await
                .expect("spawned task should not panic")
        });

        assert_eq!(value.unwrap(), 11);
    }

    #[test]
    fn shared_metadata_run_blocking_works_without_a_runtime() {
        assert_eq!(run_blocking(async { Ok::<_, String>(3usize) }).unwrap(), 3);
        assert_eq!(
            run_blocking(async { Err::<usize, _>("boom".to_string()) }),
            Err("boom".to_string())
        );
    }
}