athena-gateway 3.18.0

Portable gateway request contracts and normalization primitives for Athena
Documentation
use super::*;

/// Parses a column reference: one or more identifier segments separated by
/// `.`, optionally followed by a `::type` cast.
///
/// The input is trimmed before parsing. The resulting `path` is the parsed
/// segments joined with `.`; `cast` is whatever
/// `parse_optional_trailing_cast` returns for the remainder.
///
/// # Errors
///
/// Returns an error string when the input is blank, when a segment or cast
/// fails to parse, or when unparsed text remains after the optional cast.
pub(super) fn parse_column_expr(input: &str) -> Result<CompatibilityColumnExpr, String> {
    let expr = input.trim();
    if expr.is_empty() {
        return Err("missing column expression".to_string());
    }

    let raw = expr.as_bytes();
    let mut cursor = 0usize;
    let mut path_parts = Vec::new();
    loop {
        path_parts.push(parse_identifier_segment(expr, &mut cursor)?);
        // Only a `.` directly after a segment continues the path.
        if raw.get(cursor) != Some(&b'.') {
            break;
        }
        cursor += 1;
    }

    let cast = parse_optional_trailing_cast(expr, &mut cursor)?;
    if cursor == expr.len() {
        Ok(CompatibilityColumnExpr {
            path: path_parts.join("."),
            cast,
        })
    } else {
        // Anything left over is syntax this parser does not understand.
        Err("unsupported column expression in compatibility mode".to_string())
    }
}

/// Parses a literal value, optionally followed by a `::type` cast, into a
/// JSON value plus cast name.
///
/// The input is trimmed first. A literal starting with `'` is read with
/// `parse_string_literal`. Otherwise the token runs up to the first ASCII
/// whitespace byte or `:` and is interpreted, in order, as `null`, `true`,
/// `false` (all case-insensitive), an `i64`, or an `f64`.
///
/// # Errors
///
/// Returns an error string when the input is blank, the token is not one of
/// the supported literal forms, a float cannot be represented as a JSON
/// number, the cast fails to parse, or unparsed text remains at the end.
pub(super) fn parse_json_value_literal(input: &str) -> Result<CompatibilityValueExpr, String> {
    let literal = input.trim();
    if literal.is_empty() {
        return Err("missing value".to_string());
    }

    let mut cursor = 0usize;
    let value = if literal.starts_with('\'') {
        serde_json::Value::String(parse_string_literal(literal, &mut cursor)?)
    } else {
        // The bare token ends at whitespace or at the `:` that begins a cast.
        cursor = literal
            .bytes()
            .position(|byte| byte.is_ascii_whitespace() || byte == b':')
            .unwrap_or(literal.len());
        let word = &literal[..cursor];
        match word.to_ascii_lowercase().as_str() {
            "null" => serde_json::Value::Null,
            "true" => serde_json::Value::Bool(true),
            "false" => serde_json::Value::Bool(false),
            _ => {
                if let Ok(integer) = word.parse::<i64>() {
                    json!(integer)
                } else if let Ok(float) = word.parse::<f64>() {
                    // `from_f64` yields `None` for values JSON cannot hold.
                    serde_json::Number::from_f64(float)
                        .map(serde_json::Value::Number)
                        .ok_or_else(|| "invalid numeric literal".to_string())?
                } else {
                    return Err("unsupported literal in compatibility mode".to_string());
                }
            }
        }
    };

    let cast = parse_optional_trailing_cast(literal, &mut cursor)?;
    if cursor == literal.len() {
        Ok(CompatibilityValueExpr { value, cast })
    } else {
        Err("unsupported literal in compatibility mode".to_string())
    }
}

/// Consumes zero or more `::type` casts from `input` starting at byte offset
/// `*index`, returning the last one seen (or `None` if there was none).
///
/// ASCII whitespace before each `::` and after the final cast is skipped. A
/// cast target is one or more identifier segments separated by `.`, and is
/// returned as those segments joined with `.`. On return, `*index` points at
/// the first byte that is neither whitespace nor part of a cast.
///
/// # Errors
///
/// Propagates the error from `parse_identifier_segment` when a `::` is not
/// followed by a valid identifier segment.
pub(super) fn parse_optional_trailing_cast(
    input: &str,
    index: &mut usize,
) -> Result<Option<String>, String> {
    let raw = input.as_bytes();
    let mut cast = None;
    loop {
        while raw
            .get(*index)
            .map_or(false, |byte| byte.is_ascii_whitespace())
        {
            *index += 1;
        }

        let at_cast_marker = raw
            .get(*index..)
            .map_or(false, |rest| rest.starts_with(b"::"));
        if !at_cast_marker {
            break;
        }
        *index += 2;

        let mut type_path = parse_identifier_segment(input, index)?;
        while raw.get(*index) == Some(&b'.') {
            *index += 1;
            type_path.push('.');
            type_path.push_str(&parse_identifier_segment(input, index)?);
        }
        // A later cast replaces any earlier one.
        cast = Some(type_path);
    }
    Ok(cast)
}

/// Parses a single identifier segment from `input` starting at byte offset
/// `*index`, advancing `*index` past the consumed text.
///
/// Two forms are accepted:
///
/// * a double-quoted identifier, where `""` inside the quotes stands for one
///   literal `"`; the returned string is the unquoted, unescaped content;
/// * a bare identifier made of ASCII letters, digits and `_`.
///
/// # Errors
///
/// Returns `"missing identifier"` when `*index` is at or past the end of the
/// input or no bare-identifier character is present, and
/// `"unterminated quoted identifier"` when a quoted identifier has no closing
/// quote.
pub(super) fn parse_identifier_segment(input: &str, index: &mut usize) -> Result<String, String> {
    let bytes = input.as_bytes();
    if *index >= bytes.len() {
        return Err("missing identifier".to_string());
    }

    if bytes[*index] == b'"' {
        *index += 1;
        let mut segment = String::new();
        // Start of the run of ordinary text not yet copied into `segment`.
        // Text is copied as `&str` slices rather than byte by byte so that
        // multi-byte UTF-8 characters are preserved intact. The slice bounds
        // always sit next to an ASCII `"`, hence on character boundaries.
        let mut run_start = *index;
        while *index < bytes.len() {
            if bytes[*index] != b'"' {
                *index += 1;
                continue;
            }
            segment.push_str(&input[run_start..*index]);
            if bytes.get(*index + 1) == Some(&b'"') {
                // Doubled quote: an escaped literal `"`.
                segment.push('"');
                *index += 2;
                run_start = *index;
                continue;
            }
            *index += 1;
            return Ok(segment);
        }
        return Err("unterminated quoted identifier".to_string());
    }

    let start = *index;
    let length = bytes[start..]
        .iter()
        .take_while(|byte| byte.is_ascii_alphanumeric() || **byte == b'_')
        .count();
    if length == 0 {
        return Err("missing identifier".to_string());
    }

    *index = start + length;
    Ok(input[start..*index].to_string())
}

/// Parses a single-quoted string literal from `input` starting at byte offset
/// `*index`, advancing `*index` past the closing quote.
///
/// Inside the literal, `''` stands for one literal `'`. The returned string
/// is the unquoted, unescaped content.
///
/// # Errors
///
/// Returns `"expected string literal"` when the byte at `*index` is not `'`
/// (including when `*index` is past the end), and
/// `"unterminated string literal"` when no closing quote is found.
pub(super) fn parse_string_literal(input: &str, index: &mut usize) -> Result<String, String> {
    let bytes = input.as_bytes();
    if bytes.get(*index) != Some(&b'\'') {
        return Err("expected string literal".to_string());
    }
    *index += 1;

    let mut value = String::new();
    // Start of the run of ordinary text not yet copied into `value`. Text is
    // copied as `&str` slices rather than byte by byte so that multi-byte
    // UTF-8 characters are preserved intact. The slice bounds always sit next
    // to an ASCII `'`, hence on character boundaries.
    let mut run_start = *index;
    while *index < bytes.len() {
        if bytes[*index] != b'\'' {
            *index += 1;
            continue;
        }
        value.push_str(&input[run_start..*index]);
        if bytes.get(*index + 1) == Some(&b'\'') {
            // Doubled quote: an escaped literal `'`.
            value.push('\'');
            *index += 2;
            run_start = *index;
            continue;
        }
        *index += 1;
        return Ok(value);
    }
    Err("unterminated string literal".to_string())
}

/// Validates a foreign-key hint, which is an identifier optionally prefixed
/// with `parent.` or `child.`.
///
/// The prefix, if present, is removed and the remainder is checked with
/// `validate_identifier` under the label `foreign_key`.
///
/// # Errors
///
/// Returns the error produced by `validate_identifier` when the (unprefixed)
/// hint is not accepted.
pub(super) fn validate_foreign_key_hint(raw: &str) -> Result<(), String> {
    // `parent.` is tried first; `child.` is only considered when it is absent.
    let column = raw
        .strip_prefix("parent.")
        .or_else(|| raw.strip_prefix("child."))
        .unwrap_or(raw);
    validate_identifier(column, "foreign_key")
}

/// Checks that `identifier` is accepted by `sanitize_identifier`.
///
/// # Errors
///
/// Returns a message naming `label` and the offending `identifier` when
/// `sanitize_identifier` rejects it.
pub(super) fn validate_identifier(identifier: &str, label: &str) -> Result<(), String> {
    if sanitize_identifier(identifier).is_some() {
        Ok(())
    } else {
        Err(format!(
            "{label} '{identifier}' must be a valid SQL identifier"
        ))
    }
}

/// Returns `identifier`, trimmed and wrapped in double quotes, if it consists
/// solely of ASCII letters, digits and `_`.
///
/// Returns `None` when the trimmed identifier is empty or contains any other
/// byte.
pub(super) fn sanitize_identifier(identifier: &str) -> Option<String> {
    let candidate = identifier.trim();
    let is_word_byte = |byte: u8| byte == b'_' || byte.is_ascii_alphanumeric();
    if candidate.is_empty() || !candidate.bytes().all(is_word_byte) {
        return None;
    }
    Some(format!("\"{candidate}\""))
}

/// Renders a table reference as `schema.table` when a schema name is set, or
/// as just the table name otherwise.
pub(super) fn render_table_ref(table: &GatewayRelationSelectTableRef) -> String {
    if let Some(schema_name) = table.schema_name.as_deref() {
        format!("{schema_name}.{}", table.table_name)
    } else {
        table.table_name.clone()
    }
}