polyglot-sql 0.5.2

SQL parsing, validating, formatting, and dialect translation library
Documentation
use polyglot_sql::{
    analyze_query, scope::SourceKind, AnalyzeQueryOptions, DialectType, QueryShape,
    ReferenceConfidence, TransformKind, ValidationSchema,
};
use serde_json::json;

/// Builds the [`ValidationSchema`] fixture shared by the tests in this file.
///
/// The schema is described as JSON and then deserialized. It declares five
/// tables: `users`, `orders`, `customers`, `x` and `y`.
///
/// # Panics
///
/// Panics if the JSON description cannot be deserialized into a
/// [`ValidationSchema`].
fn schema() -> ValidationSchema {
    // Small builders keep the table list below compact and uniform.
    let column = |name: &str, data_type: &str| json!({"name": name, "type": data_type});
    let table = |name: &str, columns: Vec<serde_json::Value>| {
        json!({"name": name, "columns": columns})
    };

    let tables = vec![
        table("users", vec![column("id", "INT"), column("name", "TEXT")]),
        table(
            "orders",
            vec![
                column("id", "INT"),
                column("user_id", "INT"),
                column("customer_id", "INT"),
                column("amount", "DECIMAL(10,2)"),
                column("total", "FLOAT"),
            ],
        ),
        table(
            "customers",
            vec![column("id", "INT"), column("name", "TEXT")],
        ),
        table("x", vec![column("a", "INT")]),
        table("y", vec![column("b", "INT")]),
    ];

    let definition = json!({ "tables": tables });
    serde_json::from_value(definition).unwrap()
}

/// Analyzes a two-table join and checks the reported projections (a direct
/// column, a cast and a constant), the joined relations, the base tables, and
/// the resolved upstream reference of the cast column.
#[test]
fn analyze_query_reports_projection_relations_and_types() {
    let sql = "SELECT u.id, CAST(o.total AS TEXT) AS total_text, 1 AS one \
               FROM users AS u JOIN orders AS o ON u.id = o.user_id";
    let options = AnalyzeQueryOptions {
        dialect: DialectType::Generic,
        schema: Some(schema()),
    };
    let analysis = analyze_query(sql, options).unwrap();
    let projections = &analysis.projections;

    assert_eq!(analysis.shape, QueryShape::Select);
    assert_eq!(projections.len(), 3);

    // `u.id` is passed through untouched.
    let id_column = &projections[0];
    assert_eq!(id_column.name.as_deref(), Some("id"));
    assert_eq!(id_column.transform_kind, TransformKind::Direct);

    // `CAST(o.total AS TEXT)` is reported as a cast with its target type.
    let cast_column = &projections[1];
    assert_eq!(cast_column.transform_kind, TransformKind::Cast);
    assert_eq!(cast_column.cast_type.as_deref(), Some("TEXT"));

    // The literal `1` is a constant.
    assert_eq!(projections[2].transform_kind, TransformKind::Constant);

    let users_aliased = analysis
        .relations
        .iter()
        .any(|rel| rel.name == "users" && rel.alias.as_deref() == Some("u"));
    assert!(users_aliased);

    let orders_exposes_total = analysis
        .relations
        .iter()
        .any(|rel| rel.name == "orders" && rel.columns.contains(&"total".into()));
    assert!(orders_exposes_total);

    for expected in ["orders", "users"] {
        assert!(analysis
            .base_tables
            .iter()
            .any(|base| base.name == expected));
    }

    let total_is_resolved = cast_column.upstream.iter().any(|source| {
        source.table.as_deref() == Some("orders")
            && source.column == "total"
            && source.confidence == ReferenceConfidence::Resolved
    });
    assert!(total_is_resolved);
}

/// Checks that a column selected through a CTE is traced back to the table the
/// CTE reads from, and that this table is the only base table reported.
#[test]
fn analyze_query_follows_cte_lineage() {
    let options = AnalyzeQueryOptions {
        dialect: DialectType::Generic,
        schema: Some(schema()),
    };
    let analysis = analyze_query(
        "WITH base AS (SELECT id FROM users) SELECT id FROM base",
        options,
    )
    .unwrap();

    assert_eq!(analysis.ctes, vec!["base"]);

    let traces_to_users_id = analysis.projections[0]
        .upstream
        .iter()
        .any(|source| source.table.as_deref() == Some("users") && source.column == "id");
    assert!(traces_to_users_id);

    let base_tables = &analysis.base_tables;
    assert_eq!(base_tables.len(), 1);
    assert_eq!(base_tables[0].name, "users");
}

/// Checks that a `UNION ALL` query is classified as a set operation and that
/// its kind, `all` flag, output columns, branches, relations and base tables
/// are all reported.
#[test]
fn analyze_query_reports_set_operations() {
    let options = AnalyzeQueryOptions {
        dialect: DialectType::Generic,
        schema: Some(schema()),
    };
    let analysis = analyze_query("SELECT a FROM x UNION ALL SELECT b FROM y", options).unwrap();

    assert_eq!(analysis.shape, QueryShape::SetOperation);
    assert_eq!(analysis.set_operations.len(), 1);

    let set_op = &analysis.set_operations[0];
    assert_eq!(set_op.kind, "union");
    assert!(set_op.all);
    assert_eq!(set_op.output_columns, vec!["a"]);
    assert_eq!(set_op.branches.len(), 2);

    // Both branch tables must show up as relations...
    for expected in ["x", "y"] {
        assert!(analysis.relations.iter().any(|rel| rel.name == expected));
    }
    // ...and as base tables.
    for expected in ["x", "y"] {
        assert!(analysis
            .base_tables
            .iter()
            .any(|base| base.name == expected));
    }

    let first_branch_projection = &set_op.branches[0].projections[0];
    assert_eq!(first_branch_projection.name.as_deref(), Some("a"));
}

/// Checks that analyzing a `CREATE TABLE` statement fails with an error whose
/// message mentions that a SELECT is required.
#[test]
fn analyze_query_rejects_non_query_statements() {
    let options = AnalyzeQueryOptions {
        dialect: DialectType::Generic,
        schema: None,
    };
    let failure = analyze_query("CREATE TABLE t (a INT)", options).unwrap_err();

    let message = failure.to_string();
    assert!(message.contains("requires a SELECT"));
}

/// Checks that an upstream reference to an aliased table keeps both the
/// table's real name and the alias used in the query.
#[test]
fn analyze_query_preserves_physical_table_aliases_in_lineage() {
    let options = AnalyzeQueryOptions {
        dialect: DialectType::Generic,
        schema: Some(schema()),
    };
    let analysis = analyze_query("SELECT o.id FROM orders AS o", options).unwrap();

    let lineage = &analysis.projections[0].upstream;
    let id_source = lineage
        .iter()
        .find(|candidate| candidate.column == "id")
        .unwrap();

    assert_eq!(id_source.source_name.as_deref(), Some("orders"));
    assert_eq!(id_source.source_alias.as_deref(), Some("o"));
    assert_eq!(id_source.table.as_deref(), Some("orders"));
}

/// Checks that `o.*` in a join expands only to the columns of the table
/// aliased `o`: five named, non-star projections whose upstream references
/// all point at `orders`.
#[test]
fn analyze_query_limits_qualified_star_to_matching_source() {
    let sql = "SELECT o.* FROM orders AS o JOIN customers AS c ON o.customer_id = c.id";
    let options = AnalyzeQueryOptions {
        dialect: DialectType::Generic,
        schema: Some(schema()),
    };
    let analysis = analyze_query(sql, options).unwrap();

    assert_eq!(analysis.projections.len(), 5);

    // Compare names order-independently by sorting them first.
    let mut expanded_names = Vec::new();
    for column in analysis.projections.iter() {
        if let Some(name) = column.name.as_deref() {
            expanded_names.push(name);
        }
    }
    expanded_names.sort_unstable();
    assert_eq!(
        expanded_names,
        vec!["amount", "customer_id", "id", "total", "user_id"]
    );

    // No projection may remain an unexpanded star.
    assert!(!analysis.projections.iter().any(|column| column.is_star));

    // Every upstream reference of every projection must come from `orders`.
    assert!(analysis
        .projections
        .iter()
        .flat_map(|column| column.upstream.iter())
        .all(|source| source.table.as_deref() == Some("orders")));
}

/// Checks that an unqualified column which exists in only one of the joined
/// tables is resolved to that table, including the alias it is joined under.
#[test]
fn analyze_query_resolves_unique_unqualified_columns_with_alias_schema() {
    let sql = "SELECT amount FROM orders AS o JOIN customers AS c ON o.customer_id = c.id";
    let options = AnalyzeQueryOptions {
        dialect: DialectType::Generic,
        schema: Some(schema()),
    };
    let analysis = analyze_query(sql, options).unwrap();

    let lineage = &analysis.projections[0].upstream;
    let amount_source = lineage
        .iter()
        .find(|candidate| candidate.column == "amount")
        .unwrap();

    assert_eq!(amount_source.table.as_deref(), Some("orders"));
    assert_eq!(amount_source.source_alias.as_deref(), Some("o"));
    assert_eq!(amount_source.confidence, ReferenceConfidence::Resolved);
}

/// Checks that the type hint of a projected column carries the precision and
/// scale declared in the schema, reported as `DECIMAL(10, 2)`.
#[test]
fn analyze_query_preserves_precise_schema_type_hints() {
    let options = AnalyzeQueryOptions {
        dialect: DialectType::Generic,
        schema: Some(schema()),
    };
    let analysis = analyze_query("SELECT amount FROM orders", options).unwrap();

    let amount = &analysis.projections[0];
    assert_eq!(amount.type_hint.as_deref(), Some("DECIMAL(10, 2)"));
}

/// Checks that `SELECT *` is expanded into the schema's columns for the
/// table, leaving no star projection behind.
#[test]
fn analyze_query_expands_unqualified_star_with_schema() {
    let options = AnalyzeQueryOptions {
        dialect: DialectType::Generic,
        schema: Some(schema()),
    };
    let analysis = analyze_query("SELECT * FROM orders", options).unwrap();

    // Compare names order-independently by sorting them first.
    let mut expanded_names = Vec::new();
    for column in analysis.projections.iter() {
        if let Some(name) = column.name.as_deref() {
            expanded_names.push(name);
        }
    }
    expanded_names.sort_unstable();
    assert_eq!(
        expanded_names,
        vec!["amount", "customer_id", "id", "total", "user_id"]
    );

    assert!(!analysis.projections.iter().any(|column| column.is_star));
}

/// Checks that both `COUNT(*)` and `SUM(amount)` projections are classified
/// as aggregations.
#[test]
fn analyze_query_classifies_typed_aggregates() {
    let sql = "SELECT COUNT(*) AS rows, SUM(amount) AS amount_sum FROM orders";
    let options = AnalyzeQueryOptions {
        dialect: DialectType::Generic,
        schema: Some(schema()),
    };
    let analysis = analyze_query(sql, options).unwrap();

    let row_count = &analysis.projections[0];
    assert_eq!(row_count.transform_kind, TransformKind::Aggregation);

    let amount_sum = &analysis.projections[1];
    assert_eq!(amount_sum.transform_kind, TransformKind::Aggregation);
}

/// Checks that base tables reached only indirectly, here through a CTE that
/// is read inside a derived table, are reported next to the directly joined
/// table, and that the derived table itself appears among the relations.
#[test]
fn analyze_query_reports_transitive_base_tables() {
    let sql = "WITH paid AS (SELECT customer_id FROM orders) \
               SELECT c.name FROM customers AS c \
               JOIN (SELECT customer_id FROM paid) AS p ON c.id = p.customer_id";
    let options = AnalyzeQueryOptions {
        dialect: DialectType::Generic,
        schema: Some(schema()),
    };
    let analysis = analyze_query(sql, options).unwrap();

    // The names are compared in the order the analysis reports them.
    let mut reported_base_tables = Vec::new();
    for base in analysis.base_tables.iter() {
        reported_base_tables.push(base.name.as_str());
    }
    assert_eq!(reported_base_tables, vec!["customers", "orders"]);

    let has_customers = analysis
        .relations
        .iter()
        .any(|rel| rel.name == "customers");
    assert!(has_customers);

    let has_derived_table = analysis
        .relations
        .iter()
        .any(|rel| rel.kind == SourceKind::DerivedTable);
    assert!(has_derived_table);
}