athena-gateway 3.18.0

Portable gateway request contracts and normalization primitives for Athena
Documentation
use super::*;

/// Finds the position in `projections` of the relation token that `join` targets.
///
/// A projection item is a candidate when it is a `Relation` whose table equals the
/// table of the JOIN binding. A single candidate wins outright. When several
/// relation tokens name the same table, the JOIN alias is compared
/// (ASCII case-insensitively) against each candidate's `display_name()` and must
/// select exactly one of them.
///
/// # Errors
///
/// Returns a message when no relation token matches the JOIN table, or when the
/// candidates cannot be narrowed down to exactly one (including the case where the
/// JOIN has no alias to narrow with).
pub(super) fn find_join_relation_index(
    projections: &[CompatibilityProjectionItem],
    join: &CompatibilityJoinClause,
) -> Result<usize, String> {
    let same_table: Vec<usize> = projections
        .iter()
        .enumerate()
        .filter_map(|(position, item)| {
            let targets_join_table = matches!(
                item,
                CompatibilityProjectionItem::Relation(relation)
                    if relation.table == join.binding.table
            );
            targets_join_table.then_some(position)
        })
        .collect();

    if same_table.is_empty() {
        return Err(format!(
            "JOIN target '{}' does not match any projected relation token",
            render_table_ref(&join.binding.table)
        ));
    }
    if let [only] = same_table.as_slice() {
        return Ok(*only);
    }

    // Several tokens reference the same table: only the JOIN alias can tell them
    // apart, so a JOIN without an alias leaves nothing to narrow with.
    let alias_matched: Vec<usize> = match join.binding.alias.as_deref() {
        Some(alias) => same_table
            .into_iter()
            .filter(|&position| match &projections[position] {
                CompatibilityProjectionItem::Relation(relation) => {
                    relation.display_name().eq_ignore_ascii_case(alias)
                }
                CompatibilityProjectionItem::Column(_) => false,
            })
            .collect(),
        None => Vec::new(),
    };

    match alias_matched.as_slice() {
        [index] => Ok(*index),
        _ => Err(format!(
            "JOIN target '{}' is ambiguous across projected relation tokens",
            render_table_ref(&join.binding.table)
        )),
    }
}

/// Derives the foreign-key descriptor string for a single-equality JOIN predicate.
///
/// Both operands are resolved against the root and joined bindings; exactly one of
/// them must belong to the root side. The result is then chosen by which side
/// compares its `id` column (ASCII case-insensitive):
///
/// * root `id` against a non-`id` joined column yields `child.<joined column>`;
/// * joined `id` against a non-`id` root column yields `parent.<root column>`.
///
/// # Errors
///
/// Returns a message when an operand cannot be resolved, when both operands belong
/// to the same side, or when the comparison is not `id` on exactly one side.
pub(super) fn derive_join_foreign_key(
    predicate: &CompatibilityJoinPredicate,
    root_binding: &CompatibilityTableBinding,
    join_binding: &CompatibilityTableBinding,
) -> Result<String, String> {
    let lhs = resolve_join_operand(&predicate.left, root_binding, join_binding)?;
    let rhs = resolve_join_operand(&predicate.right, root_binding, join_binding)?;

    // Both operands on the same side means the predicate does not relate the two tables.
    if lhs.is_root == rhs.is_root {
        return Err(
            "compatibility JOIN ON clauses must compare the base table to the joined relation"
                .to_string(),
        );
    }
    let (root_column, join_column) = if lhs.is_root {
        (lhs.column, rhs.column)
    } else {
        (rhs.column, lhs.column)
    };

    let root_is_id = root_column.eq_ignore_ascii_case("id");
    let join_is_id = join_column.eq_ignore_ascii_case("id");
    match (root_is_id, join_is_id) {
        (true, false) => Ok(format!("child.{join_column}")),
        (false, true) => Ok(format!("parent.{root_column}")),
        _ => Err(
            "compatibility JOIN ON clauses must compare one side's id column to the other side's foreign key"
                .to_string(),
        ),
    }
}

/// A JOIN ON operand after its qualifier has been matched to one side of the join.
struct CompatibilityResolvedOperand {
    /// `true` when the operand's qualifier matched the root (base-table) binding,
    /// `false` when it matched the joined binding.
    is_root: bool,
    /// Last segment of the operand's dotted path, i.e. the column name as written.
    column: String,
}

/// Attributes one side of a JOIN ON predicate to either the root or the joined binding.
///
/// The operand path is split on `.` and accepted in two shapes:
///
/// * `qualifier.column` — checked with `matches_simple_qualifier`;
/// * `schema.table.column` — checked with `matches_schema_table`.
///
/// In both shapes the root binding is tried before the joined binding, so an
/// operand that matches both is attributed to the root.
///
/// # Errors
///
/// Returns a message when the operand carries a cast, or when its path has another
/// shape or matches neither binding.
fn resolve_join_operand(
    column: &CompatibilityColumnExpr,
    root_binding: &CompatibilityTableBinding,
    join_binding: &CompatibilityTableBinding,
) -> Result<CompatibilityResolvedOperand, String> {
    if column.cast.is_some() {
        return Err("JOIN ON clauses do not support casts".to_string());
    }

    let segments: Vec<&str> = column.path.split('.').collect();

    // `Some((is_root, column_name))` once the qualifier has been attributed to a side.
    let resolved = match segments.as_slice() {
        [qualifier, column_name] => {
            if root_binding.matches_simple_qualifier(qualifier) {
                Some((true, *column_name))
            } else if join_binding.matches_simple_qualifier(qualifier) {
                Some((false, *column_name))
            } else {
                None
            }
        }
        [schema_name, table_name, column_name] => {
            if root_binding.matches_schema_table(schema_name, table_name) {
                Some((true, *column_name))
            } else if join_binding.matches_schema_table(schema_name, table_name) {
                Some((false, *column_name))
            } else {
                None
            }
        }
        _ => None,
    };

    match resolved {
        Some((is_root, column_name)) => Ok(CompatibilityResolvedOperand {
            is_root,
            column: column_name.to_string(),
        }),
        None => Err(format!(
            "JOIN ON column '{}' must be qualified with the base-table alias or joined-table alias",
            column.path
        )),
    }
}

/// Parses a FROM clause body into the root table binding and its JOIN clauses.
///
/// The trimmed input must start with a table binding. Each following clause must be
/// a supported JOIN kind, a table binding, the keyword `ON`, and a predicate that
/// extends up to the next JOIN keyword reported by `find_next_join_boundary` (or to
/// the end of the input).
///
/// # Errors
///
/// Returns a message for an empty input, a comma join, `JOIN ... USING`, a JOIN
/// without `ON`, and any error raised by the binding, join-kind or predicate parsers.
pub(super) fn parse_from_clause_for_compatibility(
    input: &str,
) -> Result<(CompatibilityTableBinding, Vec<CompatibilityJoinClause>), String> {
    let source = input.trim();
    if source.is_empty() {
        return Err("missing table reference".to_string());
    }

    let mut cursor = 0usize;
    let root = parse_table_binding(source, &mut cursor)?;
    let mut joins = Vec::new();

    skip_ascii_whitespace(source, &mut cursor);
    while cursor < source.len() {
        if source.as_bytes()[cursor] == b',' {
            return Err(
                "comma joins are not supported in relation-select compatibility mode".to_string(),
            );
        }

        let kind = parse_join_kind(source, &mut cursor)?;
        let binding = parse_table_binding(source, &mut cursor)?;
        skip_ascii_whitespace(source, &mut cursor);

        if starts_with_keyword_at(source, cursor, "USING") {
            return Err(
                "JOIN USING(...) is not supported in relation-select compatibility mode"
                    .to_string(),
            );
        }
        if !starts_with_keyword_at(source, cursor, "ON") {
            return Err(
                "JOIN clauses require a single ON left.col = right.col predicate".to_string(),
            );
        }

        // The predicate text runs from just after `ON` to the next JOIN keyword, if any.
        let on_start = cursor + "ON".len();
        let on_end = find_next_join_boundary(source, on_start).unwrap_or(source.len());
        let on = parse_join_on_clause(source[on_start..on_end].trim())?;
        joins.push(CompatibilityJoinClause { kind, binding, on });

        cursor = on_end;
        skip_ascii_whitespace(source, &mut cursor);
    }

    Ok((root, joins))
}

/// Returns the smallest position, searching from `start`, at which
/// `find_top_level_keyword` reports one of the recognised JOIN keywords.
///
/// Returns `None` when none of the keywords is found.
pub(super) fn find_next_join_boundary(input: &str, start: usize) -> Option<usize> {
    const JOIN_KEYWORDS: [&str; 7] = [
        "LEFT OUTER JOIN",
        "LEFT JOIN",
        "RIGHT JOIN",
        "FULL JOIN",
        "CROSS JOIN",
        "INNER JOIN",
        "JOIN",
    ];

    // Only the position matters to callers, so the earliest hit across all keywords wins.
    JOIN_KEYWORDS
        .iter()
        .filter_map(|keyword| find_top_level_keyword(input, keyword, start))
        .min()
}

/// Consumes a JOIN keyword at `*index` (after skipping ASCII whitespace) and
/// returns the join kind it denotes.
///
/// `LEFT OUTER JOIN` and `LEFT JOIN` map to `CompatibilityJoinKind::Left`;
/// `INNER JOIN` and a bare `JOIN` map to `CompatibilityJoinKind::Inner`. On success
/// `*index` is advanced by the length of the matched keyword text.
///
/// # Errors
///
/// Returns a dedicated message for `RIGHT JOIN`, `FULL JOIN` and `CROSS JOIN`, and a
/// generic message for anything else. `*index` is not advanced past the keyword in
/// those cases.
pub(super) fn parse_join_kind(
    input: &str,
    index: &mut usize,
) -> Result<CompatibilityJoinKind, String> {
    // Recognised but rejected join kinds, paired with the message reported for each.
    const REJECTED: [(&str, &str); 3] = [
        (
            "RIGHT JOIN",
            "RIGHT JOIN is not supported in relation-select compatibility mode",
        ),
        (
            "FULL JOIN",
            "FULL JOIN is not supported in relation-select compatibility mode",
        ),
        (
            "CROSS JOIN",
            "CROSS JOIN is not supported in relation-select compatibility mode",
        ),
    ];

    skip_ascii_whitespace(input, index);

    for keyword in ["LEFT OUTER JOIN", "LEFT JOIN"] {
        if starts_with_keyword_at(input, *index, keyword) {
            *index += keyword.len();
            return Ok(CompatibilityJoinKind::Left);
        }
    }
    for keyword in ["INNER JOIN", "JOIN"] {
        if starts_with_keyword_at(input, *index, keyword) {
            *index += keyword.len();
            return Ok(CompatibilityJoinKind::Inner);
        }
    }
    for (keyword, message) in REJECTED {
        if starts_with_keyword_at(input, *index, keyword) {
            return Err(message.to_string());
        }
    }

    Err("unsupported JOIN clause in relation-select compatibility mode".to_string())
}

/// Parses a table reference at `*index`, followed by an optional alias.
///
/// The alias may be introduced with `AS` or written bare. A bare alias is only
/// attempted when the text after the table reference is not the end of the input,
/// a comma, a JOIN keyword, `ON`, or `USING`.
///
/// `USING` and `,` are treated as terminators so that `JOIN t USING (...)` and
/// `a, b` leave `*index` on the offending token with no alias consumed. The FROM
/// clause parser can then report its dedicated "JOIN USING" and "comma joins" errors
/// instead of `USING` being taken as the alias of `t`.
///
/// # Errors
///
/// Propagates errors from `parse_table_ref_from_index` and
/// `parse_identifier_segment`.
pub(super) fn parse_table_binding(
    input: &str,
    index: &mut usize,
) -> Result<CompatibilityTableBinding, String> {
    let table = parse_table_ref_from_index(input, index)?;
    skip_ascii_whitespace(input, index);

    let alias = if starts_with_keyword_at(input, *index, "AS") {
        *index += "AS".len();
        skip_ascii_whitespace(input, index);
        Some(parse_identifier_segment(input, index)?)
    } else if *index >= input.len()
        // End of input is checked first so the byte lookup below is always in bounds.
        || input.as_bytes()[*index] == b','
        || starts_with_any_keyword_at(
            input,
            *index,
            &[
                "LEFT OUTER JOIN",
                "LEFT JOIN",
                "RIGHT JOIN",
                "FULL JOIN",
                "CROSS JOIN",
                "INNER JOIN",
                "JOIN",
                "ON",
                "USING",
            ],
        )
    {
        None
    } else {
        Some(parse_identifier_segment(input, index)?)
    };

    Ok(CompatibilityTableBinding { table, alias })
}

/// Parses a complete `table` or `schema.table` reference.
///
/// Surrounding whitespace is ignored, but nothing other than whitespace may follow
/// the reference.
///
/// # Errors
///
/// Propagates errors from `parse_table_ref_from_index`, and returns a message when
/// non-whitespace text remains after the reference.
pub(super) fn parse_table_ref(input: &str) -> Result<GatewayRelationSelectTableRef, String> {
    let trimmed = input.trim();
    let mut cursor = 0usize;
    let table = parse_table_ref_from_index(trimmed, &mut cursor)?;

    let remainder = &trimmed[cursor..];
    if !remainder.trim().is_empty() {
        return Err("table reference must be 'table' or 'schema.table'".to_string());
    }
    Ok(table)
}

/// Parses a `table` or `schema.table` reference starting at `*index`.
///
/// Leading ASCII whitespace is skipped, then identifier segments separated by `.`
/// are read with `parse_identifier_segment`. One segment is a bare table name; two
/// segments are a schema name followed by a table name.
///
/// # Errors
///
/// Propagates identifier parsing errors, and returns a message when more than two
/// segments are present.
pub(super) fn parse_table_ref_from_index(
    input: &str,
    index: &mut usize,
) -> Result<GatewayRelationSelectTableRef, String> {
    skip_ascii_whitespace(input, index);

    let mut segments = vec![parse_identifier_segment(input, index)?];
    // Each `.` directly after a segment introduces one more segment.
    while input.as_bytes().get(*index) == Some(&b'.') {
        *index += 1;
        segments.push(parse_identifier_segment(input, index)?);
    }

    let mut parts = segments.into_iter();
    match (parts.next(), parts.next(), parts.next()) {
        (Some(table_name), None, _) => Ok(GatewayRelationSelectTableRef {
            schema_name: None,
            table_name,
        }),
        (Some(schema_name), Some(table_name), None) => Ok(GatewayRelationSelectTableRef {
            schema_name: Some(schema_name),
            table_name,
        }),
        _ => Err("table reference must be 'table' or 'schema.table'".to_string()),
    }
}

/// Parses the text of a JOIN ON clause into a single `left = right` predicate.
///
/// The clause must contain exactly one top-level `=`, no top-level `AND` or `OR`,
/// and neither operand may carry a cast. Both operands are trimmed and handed to
/// `parse_column_expr`.
///
/// # Errors
///
/// Returns a message when the clause is empty, combines predicates, has no
/// equality or more than one, or uses a cast; also propagates operand parsing errors.
pub(super) fn parse_join_on_clause(input: &str) -> Result<CompatibilityJoinPredicate, String> {
    if input.is_empty() {
        return Err("JOIN clauses require an ON predicate".to_string());
    }

    let combines_predicates = ["AND", "OR"]
        .iter()
        .any(|connective| find_top_level_keyword(input, connective, 0).is_some());
    if combines_predicates {
        return Err("JOIN ON clauses support only a single equality predicate".to_string());
    }

    let eq_index = find_single_top_level_equality(input).ok_or_else(|| {
        "JOIN ON clauses require a single left.col = right.col predicate".to_string()
    })?;
    let lhs_text = &input[..eq_index];
    let rhs_text = &input[eq_index + 1..];

    // A second `=` to the right of the first one means more than one comparison.
    if find_single_top_level_equality(rhs_text).is_some() {
        return Err("JOIN ON clauses support only a single equality predicate".to_string());
    }

    let left = parse_column_expr(lhs_text.trim())?;
    let right = parse_column_expr(rhs_text.trim())?;
    if left.cast.is_some() || right.cast.is_some() {
        return Err("JOIN ON clauses do not support casts".to_string());
    }

    Ok(CompatibilityJoinPredicate { left, right })
}

/// Scans `input` from offset 0 for a standalone `=` at parenthesis depth zero.
///
/// An `=` is rejected when the preceding byte is `<`, `>`, `!` or `=`, or when the
/// following byte is `=`, so `<=`, `>=`, `!=` and `==` do not count.
///
/// The result is whatever `scan_sql_until_top_level_boundary` reports for this
/// predicate — presumably the byte offset of the first match; confirm against that
/// helper.
pub(super) fn find_single_top_level_equality(input: &str) -> Option<usize> {
    scan_sql_until_top_level_boundary(input, 0, |_position, byte, before, after, depth| {
        if depth != 0 || byte != b'=' {
            return false;
        }
        let follows_operator_char = matches!(before, Some(b'<' | b'>' | b'!' | b'='));
        !follows_operator_char && after != Some(b'=')
    })
}