athena_rs 3.4.7

Database driver
Documentation
use anyhow::Result;
use serde_json::Value;
use std::collections::HashMap;
use std::str::Chars;
use uuid::Uuid;

use crate::utils::postgres_types::{jsonb_cast_for_column, timestamptz_cast_for_column};

pub mod __tests__;
pub mod condition;
pub mod operator;

pub use condition::{Condition, ConditionOperator};

use crate::parser::query_builder::operator::format_condition_clause;

/// Stateful query builder for SQL helper clauses used by gateway drivers.
///
/// Accumulates [`Condition`]s and renders them into a ` WHERE ...` clause whose
/// bind placeholders are numbered `$n` starting at `start_index`.
#[derive(Debug, Clone)]
pub struct QueryBuilder {
    /// Conditions rendered (in insertion order, joined with `AND`) by
    /// `build_where_clause`.
    conditions: Vec<Condition>,
    /// Number used for the first `$n` placeholder emitted by `build_where_clause`.
    start_index: usize,
}

impl Default for QueryBuilder {
    /// Returns a builder with no conditions whose placeholders begin at `$1`.
    fn default() -> Self {
        let start_index: usize = 1;
        Self {
            start_index,
            conditions: Vec::default(),
        }
    }
}

impl QueryBuilder {
    /// Creates a new `QueryBuilder` with no conditions and start index `1`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new `QueryBuilder` from pre-built conditions.
    pub fn from_conditions(conditions: impl IntoIterator<Item = Condition>) -> Self {
        Self::new().with_conditions(conditions)
    }

    /// Sets the placeholder start index used when rendering conditions.
    pub fn with_start_index(mut self, start_index: usize) -> Self {
        self.start_index = start_index;
        self
    }

    /// Appends a single condition to this builder.
    pub fn with_condition(mut self, condition: Condition) -> Self {
        self.conditions.push(condition);
        self
    }

    /// Appends many conditions to this builder.
    pub fn with_conditions(mut self, conditions: impl IntoIterator<Item = Condition>) -> Self {
        self.conditions.extend(conditions);
        self
    }

    /// Pushes a single condition in-place.
    pub fn push_condition(&mut self, condition: Condition) -> &mut Self {
        self.conditions.push(condition);
        self
    }

    /// Builds a WHERE clause using this builder's conditions and start index.
    pub fn build_where_clause(&self) -> Result<(String, Vec<Value>)> {
        Self::build_where_clause_from(&self.conditions, self.start_index)
    }

    /// Builds a WHERE clause for arbitrary conditions and start index.
    pub fn build_where_clause_from(
        conditions: &[Condition],
        start_index: usize,
    ) -> Result<(String, Vec<Value>)> {
        let mut clause_parts: Vec<String> = Vec::new();
        let mut values: Vec<Value> = Vec::new();
        let mut idx: usize = start_index;

        for condition in conditions {
            if let Some(column) = sanitize_identifier(&condition.column)
                && let Some(single_clause) =
                    format_condition_clause(&column, condition, &mut idx, &mut values)
            {
                clause_parts.push(single_clause);
            }
        }

        let clause: String = if clause_parts.is_empty() {
            String::new()
        } else {
            format!(" WHERE {}", clause_parts.join(" AND "))
        };

        Ok((clause, values))
    }

    /// Like [`Self::build_insert_placeholders`], but column-aware: emits
    /// `$n::timestamptz` for timestamp-shaped JSON strings when
    /// [`timestamptz_cast_for_column`] applies.
    pub fn build_insert_placeholders_for_entries<'a>(
        entries: &'a [(String, Value)],
    ) -> (Vec<String>, Vec<&'a Value>) {
        Self::build_insert_placeholders_for_entries_with_types(entries, None)
    }

    /// Like [`Self::build_insert_placeholders_for_entries`], but uses optional
    /// column type metadata to emit `$n::jsonb` for JSONB columns.
    pub fn build_insert_placeholders_for_entries_with_types<'a>(
        entries: &'a [(String, Value)],
        column_types: Option<&HashMap<String, String>>,
    ) -> (Vec<String>, Vec<&'a Value>) {
        let mut placeholders: Vec<String> = Vec::with_capacity(entries.len());
        let mut bind_values: Vec<&Value> = Vec::new();
        let mut next_param_index: i32 = 1;

        for (column, value) in entries {
            if value.is_null() {
                placeholders.push("NULL".to_string());
                continue;
            }

            let placeholder: String = if jsonb_cast_for_column(column, column_types) {
                format!("${}::jsonb", next_param_index)
            } else if timestamptz_cast_for_column(column, value) {
                format!("${}::timestamptz", next_param_index)
            } else {
                format!("${}", next_param_index)
            };
            placeholders.push(placeholder);
            bind_values.push(value);
            next_param_index += 1;
        }

        (placeholders, bind_values)
    }

    /// Determines placeholder and bind values for inserts, treating JSON null
    /// as SQL `NULL`.
    pub fn build_insert_placeholders<'a>(values: &[&'a Value]) -> (Vec<String>, Vec<&'a Value>) {
        let mut placeholders: Vec<String> = Vec::with_capacity(values.len());
        let mut bind_values: Vec<&Value> = Vec::new();
        let mut next_param_index: i32 = 1;

        for value in values {
            if value.is_null() {
                placeholders.push("NULL".to_string());
            } else {
                placeholders.push(format!("${}", next_param_index));
                bind_values.push(value);
                next_param_index += 1;
            }
        }

        (placeholders, bind_values)
    }

    /// Quotes identifiers to keep SQL generation safe.
    pub fn sanitize_identifier(identifier: &str) -> Option<String> {
        sanitize_identifier(identifier)
    }

    /// Quotes and validates table identifiers in either `table` or
    /// `schema.table` form.
    pub fn sanitize_qualified_table_identifier(table_name: &str) -> Option<String> {
        sanitize_qualified_table_identifier(table_name)
    }
}

/// Compatibility free function that delegates to [`QueryBuilder`].
pub fn build_insert_placeholders_for_entries<'a>(
    entries: &'a [(String, Value)],
) -> (Vec<String>, Vec<&'a Value>) {
    QueryBuilder::build_insert_placeholders_for_entries(entries)
}

/// Compatibility free function that delegates to [`QueryBuilder`].
pub fn build_insert_placeholders_for_entries_with_types<'a>(
    entries: &'a [(String, Value)],
    column_types: Option<&HashMap<String, String>>,
) -> (Vec<String>, Vec<&'a Value>) {
    QueryBuilder::build_insert_placeholders_for_entries_with_types(entries, column_types)
}

/// Compatibility free function that delegates to [`QueryBuilder`].
pub fn build_insert_placeholders<'a>(values: &[&'a Value]) -> (Vec<String>, Vec<&'a Value>) {
    QueryBuilder::build_insert_placeholders(values)
}

/// Quotes identifiers to keep SQL generation safe.
///
/// Accepts only ASCII identifiers that start with a letter or `_` and continue
/// with letters, digits, or `_`. Anything else (including the empty string)
/// yields `None`. Accepted names are wrapped in double quotes so reserved
/// keywords and mixed-case names are preserved verbatim.
pub fn sanitize_identifier(identifier: &str) -> Option<String> {
    let valid_head = |c: char| c == '_' || c.is_ascii_alphabetic();
    let valid_tail = |c: char| c == '_' || c.is_ascii_alphanumeric();

    let mut rest: Chars<'_> = identifier.chars();
    let head_ok: bool = rest.next().is_some_and(valid_head);

    // `&&` short-circuits, so the tail is only scanned when the head is valid.
    (head_ok && rest.all(valid_tail)).then(|| format!("\"{}\"", identifier))
}

/// Quotes and validates table identifiers in either `table` or
/// `schema.table` form.
///
/// Parsing rules:
/// - The name is split on `.`.
/// - Each segment is trimmed of surrounding whitespace and must pass
///   [`sanitize_identifier`].
/// - A single empty or invalid segment rejects the whole name (`None`).
/// - The quoted segments are re-joined with `.`.
///
/// Gateway `/gateway/insert` and fetch use this so `table_name` values like
/// `public.query_optimization_runs` map to `"public"."query_optimization_runs"`.
///
/// NOTE(review): the number of segments is not limited, so e.g. `a.b.c.d` is
/// also accepted here and left for the database to reject.
pub fn sanitize_qualified_table_identifier(table_name: &str) -> Option<String> {
    // `split` always yields at least one segment, and an empty (trimmed)
    // segment fails `sanitize_identifier`, so a successful collect is never
    // empty. No separate emptiness check is needed.
    let parts: Vec<String> = table_name
        .split('.')
        .map(|segment| sanitize_identifier(segment.trim()))
        .collect::<Option<Vec<String>>>()?;
    Some(parts.join("."))
}

/// Compatibility free function that delegates to [`QueryBuilder`].
pub fn build_where_clause(
    conditions: &[Condition],
    start_index: usize,
) -> Result<(String, Vec<Value>)> {
    QueryBuilder::build_where_clause_from(conditions, start_index)
}

/// Renders `t.<column> <comparator> $n`, binding `value` as parameter `$n`.
///
/// When [`should_cast_uuid_value_to_text`] says so, both sides are cast to
/// `::text`. Always pushes `value` onto `values`, advances `idx` by one, and
/// returns `Some`.
fn create_placeholder_clause(
    column: &str,
    idx: &mut usize,
    values: &mut Vec<Value>,
    condition: &Condition,
    value: &Value,
    comparator: &str,
) -> Option<String> {
    let as_text: bool = should_cast_uuid_value_to_text(condition, value);
    let position: usize = *idx;

    values.push(value.clone());
    *idx = position + 1;

    let rendered: String = if as_text {
        format!("t.{column}::text {comparator} ${position}::text")
    } else {
        format!("t.{column} {comparator} ${position}")
    };
    Some(rendered)
}

/// Returns `true` when the condition opts into UUID-to-text casting and
/// `value` is a JSON string that parses as a UUID.
fn should_cast_uuid_value_to_text(condition: &Condition, value: &Value) -> bool {
    if !condition.auto_cast_uuid_value_to_text {
        return false;
    }
    value
        .as_str()
        .is_some_and(|text| Uuid::parse_str(text).is_ok())
}

/// Renders an `IS`-style predicate for `column` from the condition's first
/// value.
///
/// - JSON `null`, or no value at all, renders `t.<column> IS NULL`.
/// - JSON `true` / `false` render `IS TRUE` / `IS FALSE`.
/// - Any other value is bound as parameter `$n` and compared with
///   `IS NOT DISTINCT FROM $n`, the null-safe equality. PostgreSQL's `IS`
///   only accepts keyword operands (`NULL`, `TRUE`, `FALSE`, `UNKNOWN`,
///   `DISTINCT FROM`, ...), so a bare `IS $n` is a syntax error.
///
/// Only the parameterized branch pushes onto `values` and advances `idx`.
/// Always returns `Some`.
fn build_is_clause(
    column: &str,
    condition: &Condition,
    idx: &mut usize,
    values: &mut Vec<Value>,
) -> Option<String> {
    let clause: String = match condition.values.first() {
        None | Some(Value::Null) => format!("t.{} IS NULL", column),
        Some(Value::Bool(true)) => format!("t.{} IS TRUE", column),
        Some(Value::Bool(false)) => format!("t.{} IS FALSE", column),
        Some(other) => {
            let placeholder = format!("${}", idx);
            values.push(other.clone());
            *idx += 1;
            format!("t.{} IS NOT DISTINCT FROM {}", column, placeholder)
        }
    };
    Some(clause)
}

/// Renders `t.<column> IN ($a, $b, ...)` with one placeholder per condition
/// value.
///
/// Returns `None` (no clause) when the condition has no values. Otherwise all
/// values are appended to `values` in order, and `idx` advances by their count.
fn build_in_clause(
    column: &str,
    condition: &Condition,
    idx: &mut usize,
    values: &mut Vec<Value>,
) -> Option<String> {
    let count: usize = condition.values.len();
    if count == 0 {
        return None;
    }

    let first: usize = *idx;
    let list: Vec<String> = (first..first + count).map(|n| format!("${n}")).collect();
    values.extend(condition.values.iter().cloned());
    *idx = first + count;

    Some(format!("t.{column} IN ({})", list.join(", ")))
}

/// Renders `t.<column> <operator> $n`, binding the condition's first value
/// (expected to be the array operand) as parameter `$n`.
///
/// Returns `None` when the condition has no values. Otherwise pushes one value
/// and advances `idx` by one.
fn build_array_clause(
    column: &str,
    condition: &Condition,
    idx: &mut usize,
    values: &mut Vec<Value>,
    operator: &str,
) -> Option<String> {
    condition.values.first().map(|array_value| {
        let position: usize = *idx;
        values.push(array_value.clone());
        *idx = position + 1;
        format!("t.{column} {operator} ${position}")
    })
}