//! athena_rs 3.3.0
//!
//! Database gateway API
//! Documentation
use actix_web::HttpRequest;
use serde_json::Value;
use serde_urlencoded::from_str;

use crate::utils::format::normalize_column_name;

/// Normalized form of a PostgREST-style request, as produced by
/// `parse_postgrest_query`.
#[derive(Debug)]
pub struct PostgrestQuery {
    /// Entries of the `select` parameter; `["*"]` when no non-empty selection was given.
    pub columns: Vec<String>,
    /// Filters built from every query parameter that is not one of the reserved
    /// keys (`select`, `limit`, `offset`, `order`, `or`).
    pub filters: Vec<PostgrestFilter>,
    /// One filter group per `or` parameter that yielded at least one filter.
    // NOTE(review): presumably each group is OR-combined by the consumer — confirm.
    pub or_filters: Vec<Vec<PostgrestFilter>>,
    /// Row limit from the `limit` parameter, or derived from the `Range` header.
    pub limit: Option<i64>,
    /// Row offset from the `offset` parameter, or the start of the `Range` header.
    pub offset: Option<i64>,
    /// Ordering taken from the `order` parameter, if one was parsed.
    pub order: Option<OrderSpec>,
}

/// A single parsed filter condition (`column` `operator` `values`).
#[derive(Debug)]
pub struct PostgrestFilter {
    /// Column the condition applies to (normalized when snake-casing was requested).
    pub column: String,
    /// Comparison operator parsed from the condition string.
    pub operator: FilterOperator,
    /// Operand(s): the list items for `In`, a single `Value::Array` for
    /// `Contains`/`Contained`, and a single scalar for every other operator.
    pub values: Vec<Value>,
    /// `true` when the condition string started with `not.`.
    pub negated: bool,
}

/// Comparison operators recognized in PostgREST-style filter expressions.
///
/// Each variant lists the (case-insensitive) query-string token that selects it.
/// The type is `Copy` and comparable with `==`, so callers can branch on it
/// without pattern matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterOperator {
    /// `eq` — also used as the fallback for unrecognized operator tokens.
    Eq,
    /// `neq`
    Neq,
    /// `gt`
    Gt,
    /// `lt`
    Lt,
    /// `gte`
    Gte,
    /// `lte`
    Lte,
    /// `like`
    Like,
    /// `ilike`
    ILike,
    /// `is`
    Is,
    /// `in`
    In,
    /// `cs` or `array_contains`
    Contains,
    /// `cd` or `array_contained`
    Contained,
}

/// Sort instruction for a query: a single column plus a direction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OrderSpec {
    /// Column to sort by.
    pub column: String,
    /// `true` for ascending order, `false` for descending.
    pub ascending: bool,
}

/// Parses a Supabase/PostgREST query into a normalized representation that can be
/// mapped to Athena/Postgres operations.
///
/// Query parameters are handled as follows:
/// - `select`: comma-separated column list (defaults to `*` when absent or empty),
/// - `limit` / `offset`: parsed as `i64`; values that fail to parse are ignored,
/// - `order`: parsed by `parse_order`; when repeated, the last parseable one wins,
/// - `or`: parsed by `parse_or_filter`; each occurrence adds one filter group,
/// - any other key is treated as a column name whose value is a filter expression.
///
/// A `Range` header with `end >= start` overrides both `offset` and `limit`.
///
/// `_table_name` is currently unused.
///
/// # Errors
///
/// Returns an error message when the query string cannot be decoded.
pub fn parse_postgrest_query(
    _table_name: &str,
    req: &HttpRequest,
    force_snake_case: bool,
) -> Result<PostgrestQuery, String> {
    let mut columns: Vec<String> = vec!["*".to_string()];
    let mut filters: Vec<PostgrestFilter> = Vec::new();
    let mut or_filters: Vec<Vec<PostgrestFilter>> = Vec::new();
    let mut limit: Option<i64> = None;
    let mut offset: Option<i64> = None;
    let mut order: Option<OrderSpec> = None;

    let query_pairs = from_str::<Vec<(String, String)>>(req.query_string())
        .map_err(|err| format!("failed to parse query string: {err}"))?;

    for (key, value) in query_pairs {
        match key.as_str() {
            "select" => {
                let selected: Vec<String> = value
                    .split(',')
                    .map(str::trim)
                    .filter(|part| !part.is_empty())
                    .map(str::to_string)
                    .collect();
                // An empty selection keeps the `*` default.
                if !selected.is_empty() {
                    columns = selected;
                }
            }
            "limit" => {
                if let Ok(parsed) = value.parse::<i64>() {
                    limit = Some(parsed);
                }
            }
            "offset" => {
                if let Ok(parsed) = value.parse::<i64>() {
                    offset = Some(parsed);
                }
            }
            "order" => {
                if let Some(spec) = parse_order(&value, force_snake_case) {
                    order = Some(spec);
                }
            }
            "or" => {
                or_filters.extend(
                    parse_or_filter(&value, force_snake_case).filter(|group| !group.is_empty()),
                );
            }
            other => filters.extend(parse_filter(other, &value, force_snake_case)),
        }
    }

    // The Range header is inclusive on both ends and takes precedence over the
    // `limit`/`offset` query parameters.
    if let Some((start, end)) = parse_range_header(req).filter(|(start, end)| end >= start) {
        offset = Some(start);
        // Saturate instead of overflowing: the header is client-controlled and
        // `0-9223372036854775807` would otherwise overflow `i64` on the `+ 1`.
        limit = Some(end.saturating_sub(start).saturating_add(1));
    }

    Ok(PostgrestQuery {
        columns,
        filters,
        or_filters,
        limit,
        offset,
        order,
    })
}

/// Parses an `order` parameter such as `created_at.desc` into an [`OrderSpec`].
///
/// The value is lower-cased first, so the direction suffix is matched
/// case-insensitively and the returned column name is lower-case. Only a trailing
/// `.asc` or `.desc` is interpreted; when neither is present the direction
/// defaults to ascending (PostgREST's default). The remaining text is used as the
/// column name, passed through `normalize_column_name` when `force_snake_case`
/// is set.
///
/// Returns `None` when the resulting column name is empty.
fn parse_order(value: &str, force_snake_case: bool) -> Option<OrderSpec> {
    let lowered = value.to_lowercase();

    // `strip_suffix` removes exactly the matched suffix; the two suffixes differ
    // in length, so a fixed-width truncation would leave a trailing `.` behind.
    let (name, ascending) = if let Some(name) = lowered.strip_suffix(".desc") {
        (name, false)
    } else if let Some(name) = lowered.strip_suffix(".asc") {
        (name, true)
    } else {
        (lowered.as_str(), true)
    };

    let column = if force_snake_case {
        normalize_column_name(name, true)
    } else {
        name.to_string()
    };
    if column.is_empty() {
        return None;
    }
    Some(OrderSpec { column, ascending })
}

/// Parses the value of an `or` parameter, e.g. `(age.lt.18,id.in.(1,2))`, into a
/// group of filters.
///
/// The wrapping parentheses are removed, the remainder is split on top-level
/// commas, and each `column.condition` expression is handed to `parse_filter`.
/// Expressions without a `.` or that `parse_filter` rejects are skipped.
///
/// Returns `None` when the value is empty or no expression produced a filter.
fn parse_or_filter(value: &str, force_snake_case: bool) -> Option<Vec<PostgrestFilter>> {
    let trimmed = value.trim();

    // Remove the wrapping parentheses, but never strip more closing parentheses
    // than opening ones were removed: a trailing `)` may belong to the last
    // condition (for example the list of an `in` filter).
    let without_open = trimmed.trim_start_matches('(');
    let opened = trimmed.len() - without_open.len();
    let mut inner = without_open;
    for _ in 0..opened {
        let Some(rest) = inner.strip_suffix(')') else {
            break;
        };
        inner = rest;
    }
    let inner = inner.trim();

    if inner.is_empty() {
        return None;
    }

    let filters: Vec<PostgrestFilter> = split_top_level_commas(inner)
        .into_iter()
        .map(str::trim)
        .filter(|expression| !expression.is_empty())
        .filter_map(|expression| {
            let (column, condition) = expression.split_once('.')?;
            parse_filter(column, condition, force_snake_case)
        })
        .collect();

    if filters.is_empty() {
        None
    } else {
        Some(filters)
    }
}

/// Splits `input` on commas that are not nested inside `(...)` or `{...}`, so list
/// and array operands stay attached to their condition.
fn split_top_level_commas(input: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut depth = 0usize;
    let mut start = 0usize;
    for (idx, ch) in input.char_indices() {
        match ch {
            '(' | '{' => depth += 1,
            // Saturate so a stray closing bracket cannot underflow the depth.
            ')' | '}' => depth = depth.saturating_sub(1),
            ',' if depth == 0 => {
                parts.push(&input[start..idx]);
                start = idx + 1;
            }
            _ => {}
        }
    }
    parts.push(&input[start..]);
    parts
}

/// Builds one [`PostgrestFilter`] from a column name and a condition string of the
/// form `[not.]<operator>.<value>`.
///
/// - The column is passed through `normalize_column_name` when `force_snake_case`
///   is set; an empty column yields `None`.
/// - A leading `not.` marks the filter as negated.
/// - The operator token is matched case-insensitively; unrecognized tokens fall
///   back to `Eq`.
/// - A condition without a `.` separating operator and value yields `None`.
fn parse_filter(column: &str, expression: &str, force_snake_case: bool) -> Option<PostgrestFilter> {
    let column_name = if force_snake_case {
        normalize_column_name(column, true)
    } else {
        column.to_owned()
    };
    if column_name.is_empty() {
        return None;
    }

    let (negated, condition) = match expression.strip_prefix("not.") {
        Some(rest) => (true, rest),
        None => (false, expression),
    };

    let (op_token, raw_value) = condition.split_once('.')?;

    let operator = match op_token.to_lowercase().as_str() {
        "eq" => FilterOperator::Eq,
        "neq" => FilterOperator::Neq,
        "gt" => FilterOperator::Gt,
        "lt" => FilterOperator::Lt,
        "gte" => FilterOperator::Gte,
        "lte" => FilterOperator::Lte,
        "like" => FilterOperator::Like,
        "ilike" => FilterOperator::ILike,
        "is" => FilterOperator::Is,
        "in" => FilterOperator::In,
        "cs" | "array_contains" => FilterOperator::Contains,
        "cd" | "array_contained" => FilterOperator::Contained,
        // Anything unrecognized is treated as an equality test on the value part.
        _ => FilterOperator::Eq,
    };

    // The operand shape depends on the operator: a list for `in`, a single JSON
    // array for the containment operators, a single scalar otherwise.
    let values = match operator {
        FilterOperator::In => parse_in_values(raw_value),
        FilterOperator::Contains | FilterOperator::Contained => {
            vec![parse_array_filter(raw_value)]
        }
        _ => vec![parse_scalar_value(raw_value)],
    };

    Some(PostgrestFilter {
        column: column_name,
        operator,
        values,
        negated,
    })
}

fn parse_in_values(value: &str) -> Vec<Value> {
    let trimmed = value.trim().trim_start_matches('(').trim_end_matches(')');
    trimmed
        .split(',')
        .map(|part| parse_scalar_value(part.trim()))
        .collect()
}

/// Parses the operand of a containment filter, e.g. `{a,b,c}`, into a JSON array.
///
/// Leading `.` characters, surrounding whitespace and the enclosing braces are
/// stripped; the remainder is split on commas and each trimmed element is typed by
/// `parse_scalar_value`.
fn parse_array_filter(value: &str) -> Value {
    let body = value
        .trim()
        .trim_start_matches('.')
        .trim()
        .trim_start_matches('{')
        .trim_end_matches('}')
        .trim();

    let mut items: Vec<Value> = Vec::new();
    for piece in body.split(',') {
        items.push(parse_scalar_value(piece.trim()));
    }
    Value::Array(items)
}

/// Converts a raw filter operand into a typed JSON value.
///
/// Resolution order: empty input becomes an empty string; `null`, `true` and
/// `false` (case-insensitive) become the matching JSON literals; text that parses
/// as `i64`, or as a finite `f64`, becomes a number; anything else becomes a string
/// with every `*` replaced by `%`.
// NOTE(review): the `*` -> `%` rewrite applies to every string operand, not only
// to `like`/`ilike` patterns — confirm this is intended.
fn parse_scalar_value(value: &str) -> Value {
    match value.to_lowercase().as_str() {
        "" => return Value::String(String::new()),
        "null" => return Value::Null,
        "true" => return Value::Bool(true),
        "false" => return Value::Bool(false),
        _ => {}
    }

    if let Ok(integer) = value.parse::<i64>() {
        return Value::from(integer);
    }

    // `from_f64` rejects NaN and infinities, which then fall through to the
    // string case below.
    let float = value
        .parse::<f64>()
        .ok()
        .and_then(serde_json::Number::from_f64);
    if let Some(number) = float {
        return Value::Number(number);
    }

    Value::String(value.replace('*', "%"))
}

/// Reads the `Range` request header and returns its `(start, end)` bounds.
///
/// Accepts `start-end` with an optional `items=` prefix. Returns `None` when the
/// header is missing, is not valid text, or either bound fails to parse as `i64`.
fn parse_range_header(req: &HttpRequest) -> Option<(i64, i64)> {
    let raw = req.headers().get("Range")?.to_str().ok()?;

    let mut bounds = raw
        .trim()
        .trim_start_matches("items=")
        .split('-')
        .map(|part| part.trim().parse::<i64>());

    let start = bounds.next()?.ok()?;
    let end = bounds.next()?.ok()?;
    Some((start, end))
}