athena-gateway 3.18.0

Portable gateway request contracts and normalization primitives for Athena
Documentation
use super::*;

pub(super) fn parse_where_clause_for_compatibility(
    input: &str,
    alias_context: &CompatibilityAliasContext,
) -> Result<Vec<CompatibilityTargetedValue>, String> {
    if input.trim().is_empty() {
        return Ok(Vec::new());
    }

    let mut where_filters = Vec::new();
    for condition in split_top_level_keyword(input, "AND") {
        let parsed = if let Some(index) = find_top_level_keyword(&condition, "IN", 0) {
            let column = parse_column_expr(condition[..index].trim())?;
            let (target, column) = normalize_column_target(
                column,
                &alias_context.root,
                &alias_context.relations,
                true,
            )?;
            let raw_values = condition[index + 2..].trim();
            if !raw_values.starts_with('(') || !raw_values.ends_with(')') {
                return Err("IN filters must use parenthesized literal lists".to_string());
            }
            let values = split_top_level_commas(&raw_values[1..raw_values.len() - 1])
                .into_iter()
                .map(|value| parse_json_value_literal(&value))
                .collect::<Result<Vec<_>, _>>()?;
            (target, column, "in", values)
        } else {
            let operators = ["<>", "!=", "=", ">", "<"];
            let mut matched = None;
            for operator in operators {
                if let Some(index) = find_top_level_operator(&condition, operator) {
                    matched = Some((index, operator));
                    break;
                }
            }
            let Some((index, operator)) = matched else {
                return Err("unsupported WHERE clause in compatibility mode".to_string());
            };
            let column = parse_column_expr(condition[..index].trim())?;
            let (target, column) = normalize_column_target(
                column,
                &alias_context.root,
                &alias_context.relations,
                true,
            )?;
            let value = parse_json_value_literal(condition[index + operator.len()..].trim())?;
            let normalized_operator = match operator {
                "=" => "eq",
                "<>" | "!=" => "neq",
                ">" => "gt",
                "<" => "lt",
                _ => unreachable!(),
            };
            (target, column, normalized_operator, vec![value])
        };

        let (target, column_expr, operator, values) = parsed;
        let mut column_cast = column_expr.cast;
        let mut value_cast = compatibility_shared_value_cast(&values)?;
        normalize_uuid_text_comparison_casts(operator, &mut column_cast, &mut value_cast, &values);

        let mut filter = Map::new();
        filter.insert("column".to_string(), json!(column_expr.path));
        filter.insert("operator".to_string(), json!(operator));
        if operator == "in" {
            filter.insert(
                "values".to_string(),
                Value::Array(values.into_iter().map(|value| value.value).collect()),
            );
        } else {
            let value = values
                .into_iter()
                .next()
                .ok_or_else(|| "scalar compatibility filter is missing a value".to_string())?;
            filter.insert("value".to_string(), value.value);
        }
        if let Some(column_cast) = column_cast {
            filter.insert("column_cast".to_string(), json!(column_cast));
        }
        if let Some(value_cast) = value_cast {
            filter.insert("value_cast".to_string(), json!(value_cast));
        }
        where_filters.push(CompatibilityTargetedValue {
            target,
            value: Value::Object(filter),
        });
    }

    Ok(where_filters)
}

/// Converts a compatibility-mode `ORDER BY` clause into per-target ordering objects.
///
/// `input` is split with `split_top_level_commas`; blank items are skipped. For each
/// remaining item, the text after its last whitespace character is treated as the
/// direction when it equals `asc` or `desc` (ASCII case-insensitive) and is emitted in
/// lowercase. Otherwise the whole item is the column expression and the direction is
/// `"asc"`. The column is parsed with `parse_column_expr` and resolved through
/// [`normalize_column_target`] with unqualified columns allowed.
///
/// Each emitted value is an object with `"column"` and `"direction"`. Blank input yields
/// an empty vector.
///
/// # Errors
///
/// Propagates any error from `parse_column_expr` or [`normalize_column_target`].
pub(super) fn parse_order_by_for_compatibility(
    input: &str,
    alias_context: &CompatibilityAliasContext,
) -> Result<Vec<CompatibilityTargetedValue>, String> {
    if input.trim().is_empty() {
        return Ok(Vec::new());
    }

    let mut orderings = Vec::new();
    for item in split_top_level_commas(input) {
        let entry = item.trim();
        if entry.is_empty() {
            continue;
        }

        // `entry` is already trimmed, so the piece after the last whitespace character is
        // non-empty and contains no whitespace of its own.
        let (column_text, direction) = match entry.rsplit_once(char::is_whitespace) {
            Some((head, tail))
                if tail.eq_ignore_ascii_case("asc") || tail.eq_ignore_ascii_case("desc") =>
            {
                (head.trim(), tail.to_ascii_lowercase())
            }
            _ => (entry, "asc".to_string()),
        };

        let (target, resolved) = normalize_column_target(
            parse_column_expr(column_text)?,
            &alias_context.root,
            &alias_context.relations,
            true,
        )?;
        let value = json!({
            "column": resolved.path,
            "direction": direction,
        });
        orderings.push(CompatibilityTargetedValue { target, value });
    }

    Ok(orderings)
}

/// Resolves a dotted column path to the scope it targets and strips the qualifier.
///
/// Resolution is attempted in this order:
///
/// 1. A single-segment path targets the base table, provided `allow_unqualified_root`
///    is `true`.
/// 2. A path whose first segment satisfies `root_binding.matches_simple_qualifier`
///    targets the base table with that segment removed.
/// 3. A path of three or more segments whose first two satisfy
///    `root_binding.matches_schema_table` targets the base table with both removed.
/// 4. Otherwise [`match_relation_column_path`] is consulted and, on a match, the column
///    targets that relation.
///
/// The `cast` of the incoming expression is carried over unchanged.
///
/// # Errors
///
/// Returns an error when an unqualified column is not allowed, when more than one
/// segment remains after removing a base-table qualifier, when the relation lookup
/// fails, or when nothing matches the path.
pub(super) fn normalize_column_target(
    column: CompatibilityColumnExpr,
    root_binding: &CompatibilityTableBinding,
    relations: &[CompatibilityRelationBinding],
    allow_unqualified_root: bool,
) -> Result<(CompatibilityTargetScope, CompatibilityColumnExpr), String> {
    let CompatibilityColumnExpr { path, cast } = column;
    let segments: Vec<&str> = path.split('.').collect();

    if segments.len() == 1 {
        return if allow_unqualified_root {
            Ok((
                CompatibilityTargetScope::Root,
                CompatibilityColumnExpr { path, cast },
            ))
        } else {
            Err(format!(
                "column '{}' must be qualified with the base-table alias or relation alias",
                path
            ))
        };
    }

    // `split` always yields at least one segment, so from here on there are two or more.
    // Work out how many leading segments name the base table, if any.
    let root_prefix_len = if root_binding.matches_simple_qualifier(segments[0]) {
        Some(1)
    } else if segments.len() >= 3 && root_binding.matches_schema_table(segments[0], segments[1]) {
        Some(2)
    } else {
        None
    };

    if let Some(prefix_len) = root_prefix_len {
        let remainder = &segments[prefix_len..];
        // More than one remaining segment would join into a dotted (nested) path.
        if remainder.len() > 1 {
            return Err(format!(
                "column '{}' resolves to a nested path on the base table, which is unsupported in compatibility mode",
                path
            ));
        }
        return Ok((
            CompatibilityTargetScope::Root,
            CompatibilityColumnExpr {
                path: remainder.join("."),
                cast,
            },
        ));
    }

    match match_relation_column_path(relations, &segments)? {
        Some((display_name, relative_path)) => Ok((
            CompatibilityTargetScope::Relation(display_name),
            CompatibilityColumnExpr {
                path: relative_path,
                cast,
            },
        )),
        None => Err(format!(
            "column '{}' must reference the base table or a projected relation alias",
            path
        )),
    }
}

/// Looks for the projected relation that a dotted column path is qualified with.
///
/// `parts` are the `.`-separated segments of the path. Two lookups run in order:
///
/// 1. With two or more segments, the first segment is compared against every relation's
///    `display_name` (ASCII case-insensitive) and passed to
///    `binding.matches_simple_qualifier`.
/// 2. With three or more segments, the first two are passed to
///    `binding.matches_schema_table`.
///
/// On a unique match the result is the relation's `display_name` together with the
/// remaining segments re-joined with `.`. `Ok(None)` means neither lookup matched.
///
/// # Errors
///
/// Returns an error when a qualifier matches more than one relation.
pub(super) fn match_relation_column_path(
    relations: &[CompatibilityRelationBinding],
    parts: &[&str],
) -> Result<Option<(String, String)>, String> {
    if parts.len() >= 2 {
        let qualifier = parts[0];
        let candidates: Vec<&CompatibilityRelationBinding> = relations
            .iter()
            .filter(|candidate| {
                candidate.display_name.eq_ignore_ascii_case(qualifier)
                    || candidate.binding.matches_simple_qualifier(qualifier)
            })
            .collect();
        match candidates.as_slice() {
            // No match on the simple qualifier: fall through to the schema.table lookup.
            [] => {}
            [relation] => {
                return Ok(Some((relation.display_name.clone(), parts[1..].join("."))));
            }
            _ => {
                return Err(format!(
                    "relation qualifier '{}' is ambiguous across projected relations",
                    qualifier
                ));
            }
        }
    }

    if parts.len() >= 3 {
        let (schema, table) = (parts[0], parts[1]);
        let candidates: Vec<&CompatibilityRelationBinding> = relations
            .iter()
            .filter(|candidate| candidate.binding.matches_schema_table(schema, table))
            .collect();
        match candidates.as_slice() {
            [] => {}
            [relation] => {
                return Ok(Some((relation.display_name.clone(), parts[2..].join("."))));
            }
            _ => {
                return Err(format!(
                    "relation qualifier '{}.{}' is ambiguous across projected relations",
                    schema, table
                ));
            }
        }
    }

    Ok(None)
}

/// Parses an optional non-negative integer clause.
///
/// Surrounding whitespace is ignored. Blank input yields `Ok(None)`; otherwise the text
/// must parse as an `i64` that is zero or greater.
///
/// # Errors
///
/// Returns `"{label} must be an integer"` when the text is not a valid `i64`, and
/// `"{label} must be greater than or equal to 0"` when the parsed value is negative.
pub(super) fn parse_i64_clause(input: &str, label: &str) -> Result<Option<i64>, String> {
    let text = input.trim();
    if text.is_empty() {
        return Ok(None);
    }
    match text.parse::<i64>() {
        Ok(number) if number >= 0 => Ok(Some(number)),
        Ok(_) => Err(format!("{label} must be greater than or equal to 0")),
        Err(_) => Err(format!("{label} must be an integer")),
    }
}

/// Returns the part of `raw_cast` that follows its last `.`, or all of `raw_cast` when it
/// contains no `.` (for example `"pg_catalog.text"` gives `"text"`).
pub(super) fn compatibility_cast_base(raw_cast: &str) -> &str {
    match raw_cast.rfind('.') {
        // `.` is a single byte, so `dot + 1` is always a valid slice boundary.
        Some(dot) => &raw_cast[dot + 1..],
        None => raw_cast,
    }
}

pub(super) fn compatibility_cast_is_text_or_uuid(raw_cast: Option<&str>) -> bool {
    raw_cast
        .map(compatibility_cast_base)
        .map(|value| value.eq_ignore_ascii_case("text") || value.eq_ignore_ascii_case("uuid"))
        .unwrap_or(false)
}

/// Rewrites the casts of a string comparison that involves a `text` or `uuid` cast.
///
/// Nothing changes unless all of the following hold:
///
/// * `operator` is `"eq"`, `"neq"` or `"in"`;
/// * every entry in `values` holds a JSON string;
/// * at least one of `column_cast`, `value_cast` or a per-value cast satisfies
///   [`compatibility_cast_is_text_or_uuid`].
///
/// When they do, `column_cast` is set to `"text"` and `value_cast` is cleared.
pub(super) fn normalize_uuid_text_comparison_casts(
    operator: &str,
    column_cast: &mut Option<String>,
    value_cast: &mut Option<String>,
    values: &[CompatibilityValueExpr],
) {
    let is_equality_like = operator == "eq" || operator == "neq" || operator == "in";
    if !is_equality_like {
        return;
    }

    let only_strings = values.iter().all(|entry| entry.value.is_string());
    if !only_strings {
        return;
    }

    let column_is_text_like = compatibility_cast_is_text_or_uuid(column_cast.as_deref());
    let shared_is_text_like = compatibility_cast_is_text_or_uuid(value_cast.as_deref());
    let any_value_is_text_like = values
        .iter()
        .any(|entry| compatibility_cast_is_text_or_uuid(entry.cast.as_deref()));
    if !(column_is_text_like || any_value_is_text_like || shared_is_text_like) {
        return;
    }

    *column_cast = Some("text".to_string());
    *value_cast = None;
}

/// Returns the single cast shared by every entry in `values`.
///
/// An empty slice yields `Ok(None)`. Otherwise the first entry's cast (which may itself
/// be `None`) is the expected one and is returned as an owned `String` when every other
/// entry carries the same cast.
///
/// # Errors
///
/// Returns an error when any entry's cast differs from the first entry's.
pub(super) fn compatibility_shared_value_cast(
    values: &[CompatibilityValueExpr],
) -> Result<Option<String>, String> {
    let Some((head, tail)) = values.split_first() else {
        return Ok(None);
    };
    let expected = head.cast.as_deref();
    if tail.iter().any(|entry| entry.cast.as_deref() != expected) {
        return Err(
            "compatibility IN filters require a consistent cast across all values".to_string(),
        );
    }
    Ok(expected.map(str::to_string))
}

/// Searches `input` for `operator` by delegating to `scan_sql_until_top_level_boundary`.
///
/// The scan starts at index `0`; the predicate handed to it accepts an index when
/// `slice_eq_ignore_ascii_case(input, index, operator)` holds and ignores the remaining
/// scan arguments. Whatever the scanner returns is passed back unchanged.
///
/// NOTE(review): presumably the result is the byte offset of the first match outside
/// quotes/parentheses, as the helper's name suggests. Confirm against the scanner's
/// definition.
pub(super) fn find_top_level_operator(input: &str, operator: &str) -> Option<usize> {
    scan_sql_until_top_level_boundary(input, 0, |offset, _, _, _, _| {
        slice_eq_ignore_ascii_case(input, offset, operator)
    })
}