athena_rs 3.23.0

Hyper-performant polyglot database driver
Documentation
//! Cloudflare D1 schema introspection helpers.
//!
//! These helpers keep the public `/schema/*` response contracts intact while
//! executing D1-specific catalog queries through the proxied D1 driver.

use serde::de::DeserializeOwned;
use serde_json::{Value, json};

use crate::AppState;
use crate::api::response::{internal_error, service_unavailable};
use crate::drivers::cloudflare_d1::client::{D1ConnectionInfo, execute_query_via_proxy};

use super::service_contracts::{SchemaColumnRecord, SchemaConstraintRecord, SchemaRelationRecord};

/// Lists user-defined tables and views from `sqlite_master`.
///
/// Bind parameters:
/// - `?1`: the schema label, echoed back verbatim as `table_schema`.
///
/// Each row carries `table_schema`, `table_name` and `relation_type`
/// (`'VIEW'` for views, `'BASE TABLE'` for everything else), ordered by name.
///
/// The `sqlite_` prefix filter escapes its underscore: `_` is a
/// single-character wildcard in `LIKE`, so an unescaped `'sqlite_%'` would
/// also hide user relations such as `sqlitex` or `sqlite3_cache`.
///
/// NOTE(review): only the `sqlite_` prefix is filtered. If D1 exposes its own
/// bookkeeping tables under another prefix they will be listed — confirm
/// whether those should be hidden as well.
const D1_RELATIONS_QUERY: &str = r#"
SELECT ?1 AS table_schema,
       m.name AS table_name,
       CASE m.type
           WHEN 'view' THEN 'VIEW'
           ELSE 'BASE TABLE'
       END AS relation_type
FROM sqlite_master AS m
WHERE m.type IN ('table', 'view')
  AND m.name NOT LIKE 'sqlite\_%' ESCAPE '\'
ORDER BY m.name
"#;

/// Describes the columns of a single table via `pragma_table_info`.
///
/// Bind parameters:
/// - `?1`: name of the table to describe.
///
/// Each row carries `column_name`, `data_type`, `column_default` and
/// `is_nullable`, where `is_nullable` is `'YES'` when the pragma's `notnull`
/// flag is `0` and `'NO'` otherwise. Rows are ordered by `cid`.
///
/// The query does not project `table_schema` or `table_name`; callers that
/// need those values have to supply them themselves.
const D1_COLUMNS_FOR_TABLE_QUERY: &str = r#"
SELECT name AS column_name,
       type AS data_type,
       dflt_value AS column_default,
       CASE "notnull"
           WHEN 0 THEN 'YES'
           ELSE 'NO'
       END AS is_nullable
FROM pragma_table_info(?1)
ORDER BY cid
"#;

/// Lists the columns that take part in unique indexes of a single table.
///
/// Bind parameters:
/// - `?1`: name of the table to inspect.
///
/// Every index reported by `pragma_index_list` is expanded into its member
/// columns through `pragma_index_info(idx.name)`. Only rows with
/// `"unique" = 1` and `origin = 'u'` are kept, and the result is ordered by
/// index name and then by the column's position (`seqno`) within the index.
///
/// NOTE(review): in SQLite's documented `index_list` output, origin `'u'`
/// denotes an index created by a UNIQUE constraint, which would leave out
/// primary-key (`'pk'`) and `CREATE UNIQUE INDEX` (`'c'`) indexes — confirm
/// that this exclusion is intended.
const D1_CONSTRAINTS_FOR_TABLE_QUERY: &str = r#"
SELECT idx.name AS constraint_name,
       ii.name AS column_name
FROM pragma_index_list(?1) AS idx
JOIN pragma_index_info(idx.name) AS ii
WHERE idx."unique" = 1
  AND idx.origin = 'u'
ORDER BY idx.name, ii.seqno
"#;

/// Relation and column records returned together by
/// `load_schema_overview_rows`.
#[derive(Debug, Clone)]
pub(crate) struct D1SchemaOverviewRows {
    /// Tables and views found in the database.
    pub(crate) relations: Vec<SchemaRelationRecord>,
    /// Column records gathered for those relations.
    pub(crate) columns: Vec<SchemaColumnRecord>,
}

/// Loads every relation of the database together with the columns of each
/// relation.
///
/// The relation list is fetched exactly once and then drives the per-table
/// column lookups. This avoids listing `sqlite_master` a second time and
/// guarantees that the returned columns belong to the returned relations.
///
/// # Parameters
/// - `app_state`: application state holding the HTTP client used for D1.
/// - `connection_info`: connection details of the target D1 database.
/// - `schema_label`: value reported as `table_schema` on every record.
///
/// # Errors
/// Returns the `HttpResponse` produced by the first failing lookup.
pub(super) async fn load_schema_overview_rows(
    app_state: &AppState,
    connection_info: &D1ConnectionInfo,
    schema_label: &str,
) -> Result<D1SchemaOverviewRows, actix_web::HttpResponse> {
    let relations = load_schema_table_rows(app_state, connection_info, schema_label).await?;

    // Reuse the relation names fetched above instead of letting the column
    // loader list the relations again.
    let mut columns = Vec::new();
    for relation in &relations {
        let table_columns = load_schema_column_rows(
            app_state,
            connection_info,
            schema_label,
            Some(relation.table_name.as_str()),
        )
        .await?;
        columns.extend(table_columns);
    }

    Ok(D1SchemaOverviewRows { relations, columns })
}

/// Fetches the tables and views of the D1 database as relation records.
///
/// `schema_label` is bound as the query's only parameter and is also written
/// into `table_schema` of every decoded record, so the label is reported
/// regardless of what the database echoed back.
///
/// # Errors
/// Returns the `HttpResponse` produced by `execute_d1_query` when the query
/// or the row decoding fails.
pub(super) async fn load_schema_table_rows(
    app_state: &AppState,
    connection_info: &D1ConnectionInfo,
    schema_label: &str,
) -> Result<Vec<SchemaRelationRecord>, actix_web::HttpResponse> {
    let bind_params = vec![json!(schema_label)];
    let fetched: Vec<SchemaRelationRecord> = execute_d1_query(
        app_state,
        connection_info,
        D1_RELATIONS_QUERY,
        bind_params,
        "tables",
    )
    .await?;

    let labelled = fetched
        .into_iter()
        .map(|mut relation| {
            relation.table_schema = schema_label.to_string();
            relation
        })
        .collect();

    Ok(labelled)
}

pub(super) async fn load_schema_column_rows(
    app_state: &AppState,
    connection_info: &D1ConnectionInfo,
    schema_label: &str,
    table_filter: Option<&str>,
) -> Result<Vec<SchemaColumnRecord>, actix_web::HttpResponse> {
    let table_names: Vec<String> = match table_filter {
        Some(table_name) => vec![table_name.to_string()],
        None => load_schema_table_rows(app_state, connection_info, schema_label)
            .await?
            .into_iter()
            .map(|relation| relation.table_name)
            .collect(),
    };

    let mut columns = Vec::new();
    for table_name in table_names {
        let mut table_columns: Vec<SchemaColumnRecord> = execute_d1_query(
            app_state,
            connection_info,
            D1_COLUMNS_FOR_TABLE_QUERY,
            vec![json!(table_name.clone())],
            "columns",
        )
        .await?;
        for column in &mut table_columns {
            column.table_schema = schema_label.to_string();
            column.table_name = table_name.clone();
        }
        columns.extend(table_columns);
    }

    Ok(columns)
}

/// Fetches the unique-constraint column records of a single table.
///
/// `table_name` is bound as the only parameter of
/// `D1_CONSTRAINTS_FOR_TABLE_QUERY`.
///
/// # Errors
/// Returns the `HttpResponse` produced by `execute_d1_query` when the query
/// or the row decoding fails.
pub(super) async fn load_schema_constraint_rows(
    app_state: &AppState,
    connection_info: &D1ConnectionInfo,
    table_name: &str,
) -> Result<Vec<SchemaConstraintRecord>, actix_web::HttpResponse> {
    let bind_params = vec![json!(table_name)];
    let constraints: Vec<SchemaConstraintRecord> = execute_d1_query(
        app_state,
        connection_info,
        D1_CONSTRAINTS_FOR_TABLE_QUERY,
        bind_params,
        "constraints",
    )
    .await?;

    Ok(constraints)
}

/// Runs `query` against D1 through the proxy and decodes every returned row
/// into `T`.
///
/// # Parameters
/// - `app_state`: supplies the HTTP client handed to the proxy call.
/// - `connection_info`: connection details of the target D1 database.
/// - `query` / `params`: SQL text and its positional bind values.
/// - `operation`: short label interpolated into error messages.
///
/// The trailing `None, None, false` arguments are passed to
/// `execute_query_via_proxy` unchanged; their meaning is defined by that
/// function.
///
/// # Errors
/// - A proxy failure is turned into the response built by
///   `service_unavailable`.
/// - A row that cannot be decoded yields the response from `parse_d1_rows`.
async fn execute_d1_query<T>(
    app_state: &AppState,
    connection_info: &D1ConnectionInfo,
    query: &str,
    params: Vec<Value>,
    operation: &str,
) -> Result<Vec<T>, actix_web::HttpResponse>
where
    T: DeserializeOwned,
{
    let response = execute_query_via_proxy(
        &app_state.client,
        connection_info,
        query,
        params,
        None,
        None,
        false,
    )
    .await
    .map_err(|err| {
        service_unavailable(
            format!("Failed to fetch {operation}"),
            format!("Failed to fetch {operation} from D1: {err}"),
        )
    })?;

    parse_d1_rows(response.rows, operation)
}

/// Decodes raw D1 JSON rows into typed records.
///
/// Rows are decoded in order and decoding stops at the first failure.
///
/// # Errors
/// Returns the response built by `internal_error`, mentioning `operation` and
/// the decode error, for the first row that does not deserialize into `T`.
fn parse_d1_rows<T>(rows: Vec<Value>, operation: &str) -> Result<Vec<T>, actix_web::HttpResponse>
where
    T: DeserializeOwned,
{
    rows.into_iter()
        .map(|row| {
            serde_json::from_value::<T>(row).map_err(|err| {
                internal_error(
                    format!("Failed to decode D1 {operation} rows"),
                    format!("Failed to decode D1 {operation} rows: {err}"),
                )
            })
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::schema::service_contracts::{
        SchemaColumnRecord, SchemaConstraintRecord, SchemaRelationRecord,
    };

    /// The relation listing reads `sqlite_master` and emits `BASE TABLE`.
    #[test]
    fn relations_query_uses_sqlite_master() {
        for needle in ["sqlite_master", "BASE TABLE"] {
            assert!(D1_RELATIONS_QUERY.contains(needle));
        }
    }

    /// The column listing relies on `pragma_table_info` and exposes
    /// `is_nullable`.
    #[test]
    fn columns_query_uses_table_info_pragma() {
        for needle in ["pragma_table_info", "is_nullable"] {
            assert!(D1_COLUMNS_FOR_TABLE_QUERY.contains(needle));
        }
    }

    /// The constraint listing relies on both index pragmas.
    #[test]
    fn constraints_query_uses_unique_index_pragmas() {
        for needle in ["pragma_index_list", "pragma_index_info"] {
            assert!(D1_CONSTRAINTS_FOR_TABLE_QUERY.contains(needle));
        }
    }

    /// A relation row in the public contract shape decodes field by field.
    #[test]
    fn parses_relation_row_contracts_from_d1_json() {
        let payload = json!({
            "table_schema": "public",
            "table_name": "users",
            "relation_type": "BASE TABLE"
        });
        let decoded: Vec<SchemaRelationRecord> =
            parse_d1_rows(vec![payload], "tables").expect("rows decode");
        let relation = &decoded[0];
        assert_eq!(relation.table_name, "users");
        assert_eq!(relation.table_schema, "public");
    }

    /// A column row in the public contract shape decodes, including the
    /// nullable flag.
    #[test]
    fn parses_column_row_contracts_from_d1_json() {
        let payload = json!({
            "table_schema": "public",
            "table_name": "users",
            "column_name": "id",
            "data_type": "INTEGER",
            "column_default": null,
            "is_nullable": "NO"
        });
        let decoded: Vec<SchemaColumnRecord> =
            parse_d1_rows(vec![payload], "columns").expect("rows decode");
        let column = &decoded[0];
        assert_eq!(column.column_name, "id");
        assert_eq!(column.is_nullable.as_deref(), Some("NO"));
    }

    /// A constraint row in the public contract shape decodes.
    #[test]
    fn parses_constraint_row_contracts_from_d1_json() {
        let payload = json!({
            "constraint_name": "users_email_key",
            "column_name": "email"
        });
        let decoded: Vec<SchemaConstraintRecord> =
            parse_d1_rows(vec![payload], "constraints").expect("rows decode");
        let constraint = &decoded[0];
        assert_eq!(constraint.constraint_name, "users_email_key");
    }
}