athena_rs 3.3.0

Database gateway API
Documentation
use anyhow::Result;
use serde_json::Value;
use std::str::Chars;
use uuid::Uuid;

use crate::utils::postgres_types::timestamptz_cast_for_column;

/// Like [`build_insert_placeholders`], but column-aware: emits `$n::timestamptz` for
/// timestamp-shaped JSON strings when [`timestamptz_cast_for_column`] applies.
///
/// Each entry yields exactly one placeholder, in entry order. JSON `null` values render as
/// the literal `NULL` and are not bound; every other value is bound and referenced by a
/// positional parameter numbered consecutively from `$1`. The returned bind values are in
/// parameter order.
pub fn build_insert_placeholders_for_entries<'a>(
    entries: &'a [(String, Value)],
) -> (Vec<String>, Vec<&'a Value>) {
    let mut rendered: Vec<String> = Vec::with_capacity(entries.len());
    let mut bound: Vec<&'a Value> = Vec::new();

    for (column, value) in entries {
        if value.is_null() {
            rendered.push("NULL".to_string());
        } else {
            let cast = if timestamptz_cast_for_column(column, value) {
                "::timestamptz"
            } else {
                ""
            };
            bound.push(value);
            // The parameter number is the 1-based position of this value in `bound`.
            let position = bound.len();
            rendered.push(format!("${position}{cast}"));
        }
    }

    (rendered, bound)
}

/// Determines the placeholder and bind values for inserts, treating JSON null as SQL NULL.
///
/// Returns one placeholder per input value, in order: the literal `NULL` for JSON `null`
/// (nothing is bound for it) and `$n` otherwise, where `n` counts only bound values,
/// starting at 1. The second vector holds the bound values in parameter order.
pub fn build_insert_placeholders<'a>(values: &[&'a Value]) -> (Vec<String>, Vec<&'a Value>) {
    let mut placeholders: Vec<String> = Vec::with_capacity(values.len());
    let mut bind_values: Vec<&'a Value> = Vec::new();

    for &value in values {
        if value.is_null() {
            placeholders.push("NULL".to_string());
            continue;
        }
        bind_values.push(value);
        let position = bind_values.len();
        placeholders.push(format!("${position}"));
    }

    (placeholders, bind_values)
}

/// Supported comparison operators used to describe a condition.
///
/// Each variant is rendered against the `t` table alias when a WHERE clause is built.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConditionOperator {
    /// Equality, rendered as `=`.
    Eq,
    /// Inequality, rendered as `<>`.
    Neq,
    /// Strictly greater than, rendered as `>`.
    Gt,
    /// Strictly less than, rendered as `<`.
    Lt,
    /// Membership in the condition's value list, rendered as `IN (...)`.
    In,
    /// Greater than or equal, rendered as `>=`.
    Gte,
    /// Less than or equal, rendered as `<=`.
    Lte,
    /// Case-sensitive pattern match, rendered as `LIKE`.
    Like,
    /// Case-insensitive pattern match, rendered as `ILIKE`.
    ILike,
    /// `IS` test; null and boolean operands render as `IS NULL`, `IS TRUE`, or `IS FALSE`.
    Is,
    /// Containment, rendered as `@>` (the column contains the operand).
    Contains,
    /// Contained-by, rendered as `<@` (the column is contained by the operand).
    Contained,
}

/// Represents a filter condition used for building SQL queries.
#[derive(Debug, Clone)]
pub struct Condition {
    /// Unquoted column name; it is validated and quoted with [`sanitize_identifier`] when
    /// the WHERE clause is built.
    pub column: String,
    /// How `values` is compared against the column.
    pub operator: ConditionOperator,
    /// Operands. `In` binds every entry; the other operators use only the first entry.
    pub values: Vec<Value>,
    /// When `true`, the rendered clause is wrapped in `NOT (...)`.
    pub negated: bool,
    /// When `true`, `Eq`/`Neq`/`Gt`/`Gte`/`Lt`/`Lte`/`Like`/`ILike` comparisons against a
    /// string operand that parses as a UUID cast both sides to `text`. The constructors
    /// default this to `true`.
    pub auto_cast_uuid_value_to_text: bool,
}

impl Condition {
    /// Shorthand for a non-negated [`ConditionOperator::Eq`] condition comparing `column`
    /// with `value`; UUID text casting is enabled.
    pub fn eq(column: impl Into<String>, value: impl Into<Value>) -> Self {
        let column: String = column.into();
        let value: Value = value.into();
        Self::new(column, ConditionOperator::Eq, vec![value], false)
    }

    /// Creates a condition from its parts, with UUID text casting enabled by default.
    pub fn new(
        column: impl Into<String>,
        operator: ConditionOperator,
        values: Vec<Value>,
        negated: bool,
    ) -> Self {
        let column: String = column.into();
        Self {
            column,
            operator,
            values,
            negated,
            auto_cast_uuid_value_to_text: true,
        }
    }

    /// Returns the condition with UUID-to-text casting turned on or off.
    ///
    /// When disabled, comparisons against UUID-shaped strings bind the raw column and
    /// placeholder without `::text` casts.
    pub fn with_uuid_value_text_cast(self, enabled: bool) -> Self {
        Self {
            auto_cast_uuid_value_to_text: enabled,
            ..self
        }
    }
}

/// Validates `identifier` as a plain SQL identifier and returns it wrapped in double quotes.
///
/// An accepted identifier starts with an ASCII letter or `_` and continues with ASCII
/// letters, digits, or `_`. Anything else, including the empty string, yields `None`.
/// Because `"` can never pass this check, no quote escaping is required; quoting keeps
/// mixed-case names intact and lets reserved keywords be used as identifiers.
pub fn sanitize_identifier(identifier: &str) -> Option<String> {
    let mut remaining: Chars<'_> = identifier.chars();
    let leading = remaining.next()?;
    let leading_ok = leading == '_' || leading.is_ascii_alphabetic();
    if leading_ok && remaining.all(|ch| ch == '_' || ch.is_ascii_alphanumeric()) {
        Some(format!("\"{identifier}\""))
    } else {
        None
    }
}

/// Quotes and validates table identifiers in either `table` or `schema.table` form.
/// Gateway `/gateway/insert` and fetch use this so `table_name` values like
/// `public.query_optimization_runs` map to `"public"."query_optimization_runs"`.
///
/// Whitespace around each dot-separated segment is trimmed, and every segment must pass
/// [`sanitize_identifier`]. Returns `None` for empty input, empty segments (`a..b`, `.a`),
/// invalid characters, or more than two segments.
pub fn sanitize_qualified_table_identifier(table_name: &str) -> Option<String> {
    // `sanitize_identifier` already rejects empty strings, so empty or whitespace-only
    // segments fail here too. `split` always yields at least one segment, so `parts` is
    // never empty after a successful collect.
    let parts = table_name
        .split('.')
        .map(|segment| sanitize_identifier(segment.trim()))
        .collect::<Option<Vec<String>>>()?;

    // Only `table` and `schema.table` are supported; longer dotted names are rejected up
    // front rather than being passed through to the database.
    if parts.len() > 2 {
        return None;
    }

    Some(parts.join("."))
}

/// Builds a WHERE clause for a list of filter conditions.
pub fn build_where_clause(
    conditions: &[Condition],
    start_index: usize,
) -> Result<(String, Vec<Value>)> {
    let mut clause_parts: Vec<String> = Vec::new();
    let mut values: Vec<Value> = Vec::new();
    let mut idx: usize = start_index;

    for condition in conditions {
        if let Some(column) = sanitize_identifier(&condition.column)
            && let Some(single_clause) =
                format_condition_clause(&column, condition, &mut idx, &mut values)
        {
            clause_parts.push(single_clause);
        }
    }

    let clause = if clause_parts.is_empty() {
        String::new()
    } else {
        format!(" WHERE {}", clause_parts.join(" AND "))
    };

    Ok((clause, values))
}

/// Renders one condition against the already-quoted `column`.
///
/// Bound operands are appended to `values` and `*idx` is advanced past each placeholder
/// used. When `condition.negated` is set, the clause is wrapped in `NOT (...)`. Returns
/// `None` when the operator's builder cannot produce a clause, e.g. a comparison with no
/// operand.
pub(crate) fn format_condition_clause(
    column: &str,
    condition: &Condition,
    idx: &mut usize,
    values: &mut Vec<Value>,
) -> Option<String> {
    /// Renders a binary comparison against the condition's first operand, or `None` when
    /// the condition carries no operand.
    fn compare_first(
        column: &str,
        condition: &Condition,
        idx: &mut usize,
        values: &mut Vec<Value>,
        comparator: &str,
    ) -> Option<String> {
        let operand = condition.values.first()?;
        create_placeholder_clause(column, idx, values, condition, operand, comparator)
    }

    let rendered = match condition.operator {
        ConditionOperator::Is => build_is_clause(column, condition, idx, values)?,
        ConditionOperator::In => build_in_clause(column, condition, idx, values)?,
        ConditionOperator::Contains => build_array_clause(column, condition, idx, values, "@>")?,
        ConditionOperator::Contained => build_array_clause(column, condition, idx, values, "<@")?,
        ConditionOperator::Eq => compare_first(column, condition, idx, values, "=")?,
        ConditionOperator::Neq => compare_first(column, condition, idx, values, "<>")?,
        ConditionOperator::Gt => compare_first(column, condition, idx, values, ">")?,
        ConditionOperator::Gte => compare_first(column, condition, idx, values, ">=")?,
        ConditionOperator::Lt => compare_first(column, condition, idx, values, "<")?,
        ConditionOperator::Lte => compare_first(column, condition, idx, values, "<=")?,
        ConditionOperator::Like => compare_first(column, condition, idx, values, "LIKE")?,
        ConditionOperator::ILike => compare_first(column, condition, idx, values, "ILIKE")?,
    };

    if condition.negated {
        Some(format!("NOT ({rendered})"))
    } else {
        Some(rendered)
    }
}

/// Renders `t.<column> <comparator> $n`, binding `value` as parameter `n = *idx`.
///
/// When [`should_cast_uuid_value_to_text`] holds, both the column and the placeholder get
/// a `::text` cast. `value` is appended to `values` and `*idx` is incremented. Always
/// returns `Some`.
fn create_placeholder_clause(
    column: &str,
    idx: &mut usize,
    values: &mut Vec<Value>,
    condition: &Condition,
    value: &Value,
    comparator: &str,
) -> Option<String> {
    let cast_suffix = if should_cast_uuid_value_to_text(condition, value) {
        "::text"
    } else {
        ""
    };

    let position = *idx;
    *idx = position + 1;
    values.push(value.clone());

    Some(format!(
        "t.{column}{cast_suffix} {comparator} ${position}{cast_suffix}"
    ))
}

/// Reports whether a comparison against `value` should cast both sides to `text`.
///
/// True only when the condition allows it (`auto_cast_uuid_value_to_text`) and `value`
/// is a JSON string that `Uuid::parse_str` accepts.
fn should_cast_uuid_value_to_text(condition: &Condition, value: &Value) -> bool {
    if !condition.auto_cast_uuid_value_to_text {
        return false;
    }
    match value.as_str() {
        Some(text) => Uuid::parse_str(text).is_ok(),
        None => false,
    }
}

/// Renders an `IS` predicate for `column`.
///
/// JSON `null`, or a missing operand, becomes `IS NULL`; booleans become `IS TRUE` /
/// `IS FALSE`. Nothing is bound in those cases. Any other operand is bound as parameter
/// `n = *idx` and compared with `IS NOT DISTINCT FROM $n`, the null-safe equality test.
/// PostgreSQL's bare `IS` only accepts `NULL`, `TRUE`, `FALSE`, or `UNKNOWN`, so
/// `t.col IS $n` would be a syntax error. Always returns `Some`.
fn build_is_clause(
    column: &str,
    condition: &Condition,
    idx: &mut usize,
    values: &mut Vec<Value>,
) -> Option<String> {
    let predicate = match condition.values.first() {
        None | Some(Value::Null) => "IS NULL".to_string(),
        Some(Value::Bool(true)) => "IS TRUE".to_string(),
        Some(Value::Bool(false)) => "IS FALSE".to_string(),
        Some(other) => {
            let position = *idx;
            values.push(other.clone());
            *idx += 1;
            format!("IS NOT DISTINCT FROM ${position}")
        }
    };
    Some(format!("t.{column} {predicate}"))
}

/// Renders an `IN` membership test that binds every entry of `condition.values`.
///
/// Placeholders are numbered consecutively from `*idx`, and the values are appended to
/// `values` in the same order. `*idx` is then advanced past them. An empty value list
/// renders as the constant `FALSE`, because `IN ()` is not valid SQL and a membership test
/// against no candidates matches no rows. A negated empty list therefore becomes
/// `NOT (FALSE)` and matches every row, instead of the condition being dropped. Always
/// returns `Some`; the `Option` keeps the signature uniform with the other clause builders.
fn build_in_clause(
    column: &str,
    condition: &Condition,
    idx: &mut usize,
    values: &mut Vec<Value>,
) -> Option<String> {
    if condition.values.is_empty() {
        return Some("FALSE".to_string());
    }

    let first = *idx;
    *idx += condition.values.len();
    values.extend(condition.values.iter().cloned());

    let placeholders = (first..*idx)
        .map(|position| format!("${position}"))
        .collect::<Vec<String>>()
        .join(", ");
    Some(format!("t.{column} IN ({placeholders})"))
}

/// Renders `t.<column> <operator> $n` for an array operator such as `@>` or `<@`.
///
/// Binds a clone of the condition's first value as parameter `n = *idx` and advances
/// `*idx`. Returns `None`, binding nothing, when the condition has no values.
fn build_array_clause(
    column: &str,
    condition: &Condition,
    idx: &mut usize,
    values: &mut Vec<Value>,
    operator: &str,
) -> Option<String> {
    let operand = condition.values.first()?.clone();
    let position = *idx;
    *idx = position + 1;
    values.push(operand);
    Some(format!("t.{column} {operator} ${position}"))
}

#[cfg(test)]
mod tests {
    use super::{
        Condition, ConditionOperator, build_where_clause, sanitize_identifier,
        sanitize_qualified_table_identifier,
    };
    use serde_json::json;

    #[test]
    fn build_where_clause_casts_uuid_comparisons_to_text_by_default() {
        let conditions = vec![Condition::eq(
            "workspace_id",
            json!("550e8400-e29b-41d4-a716-446655440000"),
        )];

        let (clause, values) = build_where_clause(&conditions, 1).expect("where clause");

        assert_eq!(clause, " WHERE t.\"workspace_id\"::text = $1::text");
        assert_eq!(values, vec![json!("550e8400-e29b-41d4-a716-446655440000")]);
    }

    #[test]
    fn build_where_clause_can_disable_uuid_text_casts() {
        let conditions = vec![
            Condition::eq(
                "workspace_id",
                json!("550e8400-e29b-41d4-a716-446655440000"),
            )
            .with_uuid_value_text_cast(false),
        ];

        let (clause, values) = build_where_clause(&conditions, 1).expect("where clause");

        assert_eq!(clause, " WHERE t.\"workspace_id\" = $1");
        assert_eq!(values, vec![json!("550e8400-e29b-41d4-a716-446655440000")]);
    }

    #[test]
    fn build_where_clause_numbers_placeholders_from_start_index() {
        let conditions = vec![
            Condition::new(
                "id",
                ConditionOperator::In,
                vec![json!(1), json!(2)],
                false,
            ),
            Condition::new("status", ConditionOperator::Neq, vec![json!("done")], true),
        ];

        let (clause, values) = build_where_clause(&conditions, 3).expect("where clause");

        assert_eq!(
            clause,
            " WHERE t.\"id\" IN ($3, $4) AND NOT (t.\"status\" <> $5)"
        );
        assert_eq!(values, vec![json!(1), json!(2), json!("done")]);
    }

    #[test]
    fn build_where_clause_renders_is_null_without_binding() {
        let conditions = vec![Condition::new(
            "deleted_at",
            ConditionOperator::Is,
            vec![json!(null)],
            false,
        )];

        let (clause, values) = build_where_clause(&conditions, 1).expect("where clause");

        assert_eq!(clause, " WHERE t.\"deleted_at\" IS NULL");
        assert!(values.is_empty());
    }

    #[test]
    fn build_where_clause_with_no_conditions_is_empty() {
        let (clause, values) = build_where_clause(&[], 1).expect("where clause");

        assert!(clause.is_empty());
        assert!(values.is_empty());
    }

    #[test]
    fn sanitize_identifier_quotes_valid_and_rejects_invalid() {
        assert_eq!(
            sanitize_identifier("workspace_id"),
            Some("\"workspace_id\"".to_string())
        );
        assert_eq!(sanitize_identifier("_Mixed1"), Some("\"_Mixed1\"".to_string()));
        assert_eq!(sanitize_identifier(""), None);
        assert_eq!(sanitize_identifier("1col"), None);
        assert_eq!(sanitize_identifier("col;drop"), None);
        assert_eq!(sanitize_identifier("a\"b"), None);
    }

    #[test]
    fn sanitize_qualified_table_identifier_quotes_each_segment() {
        assert_eq!(
            sanitize_qualified_table_identifier("public.query_optimization_runs"),
            Some("\"public\".\"query_optimization_runs\"".to_string())
        );
        assert_eq!(
            sanitize_qualified_table_identifier("users"),
            Some("\"users\"".to_string())
        );
        assert_eq!(sanitize_qualified_table_identifier(""), None);
        assert_eq!(sanitize_qualified_table_identifier("public..users"), None);
        assert_eq!(sanitize_qualified_table_identifier("public.users;drop"), None);
    }
}