// tideorm 0.9.14
//
// A developer-friendly ORM for Rust with clean, expressive syntax.
// Documentation
use super::*;

impl<M: Model> BatchUpdateBuilder<M> {
    /// Checks that `column` may be used as the target of a batch `UPDATE ... SET`.
    ///
    /// The name must be a plain ASCII identifier (an ASCII letter or `_`, then
    /// ASCII letters, digits, or `_`) and must also be recognised by
    /// `M::column_from_str`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-query error if either check fails.
    pub(crate) fn validate_update_column(column: &str) -> Result<()> {
        if is_plain_ascii_identifier(column) && M::column_from_str(column).is_some() {
            Ok(())
        } else {
            Err(Error::invalid_query(format!(
                "unsafe update column '{}': batch updates require a known model field/column name using only ASCII letters, numbers, and underscores",
                column
            )))
        }
    }

    /// Validates `column` and returns it quoted as an identifier for `db_type`.
    ///
    /// When `M::canonical_column_name` maps `column` to a canonical name, that
    /// name is quoted instead of the caller's spelling.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Self::validate_update_column`].
    pub(crate) fn quote_update_column(
        column: &str,
        db_type: crate::config::DatabaseType,
    ) -> Result<String> {
        Self::validate_update_column(column)?;
        let canonical_column = M::canonical_column_name(column).unwrap_or(column);
        Ok(Self::quote_identifier(canonical_column, db_type))
    }

    /// Quotes `name` as an identifier for `db_type` by delegating to `quote_ident`.
    pub(crate) fn quote_identifier(name: &str, db_type: crate::config::DatabaseType) -> String {
        quote_ident(db_type, name)
    }

    /// Returns `true` when at least one condition has been added to this builder.
    pub(crate) fn has_explicit_filters(&self) -> bool {
        !self.conditions.is_empty()
    }

    /// Refuses to proceed unless the builder has at least one explicit filter.
    ///
    /// `operation` is only used to name the rejected action in the error message.
    ///
    /// # Errors
    ///
    /// Returns an invalid-query error when no conditions have been added.
    pub(crate) fn ensure_explicit_filters(&self, operation: &str) -> Result<()> {
        if self.has_explicit_filters() {
            Ok(())
        } else {
            Err(Error::invalid_query(format!(
                "{} requires at least one explicit filter; unfiltered bulk mutations are blocked",
                operation
            )))
        }
    }

    /// Validates a `$.field` / `$.field.subfield` style JSON path and returns
    /// its segments, without the leading `$.`.
    ///
    /// Every segment must be a plain ASCII identifier (an ASCII letter or `_`,
    /// then ASCII letters, digits, or `_`).
    ///
    /// # Errors
    ///
    /// Returns an invalid-query error if `path` does not start with `$.`, or if
    /// any segment is empty or is not a plain identifier.
    pub(crate) fn validate_json_path(path: &str) -> Result<Vec<&str>> {
        let stripped = path.strip_prefix("$.").ok_or_else(|| {
            Error::invalid_query(format!(
                "unsafe JSON path '{}': only $.field or $.field.subfield paths are supported",
                path
            ))
        })?;

        // `split` never yields zero items: an empty remainder or a stray `.`
        // produces an empty segment, which the identifier check rejects.
        let segments: Vec<&str> = stripped.split('.').collect();
        if !segments.iter().copied().all(is_plain_ascii_identifier) {
            return Err(Error::invalid_query(format!(
                "unsafe JSON path '{}': only simple identifier segments are supported",
                path
            )));
        }

        Ok(segments)
    }

    /// Builds a PostgreSQL text-array path literal such as `{"a","b"}`.
    ///
    /// Segments are inserted verbatim, with no escaping, so callers should only
    /// pass segments that passed [`Self::validate_json_path`].
    pub(crate) fn postgres_json_path_literal(segments: &[&str]) -> String {
        format!(
            "{{{}}}",
            segments
                .iter()
                .map(|segment| format!("\"{}\"", segment))
                .collect::<Vec<_>>()
                .join(",")
        )
    }

    /// Shifts every PostgreSQL positional placeholder in `sql` (`$1`, `$2`, ...)
    /// up by `offset`, leaving all other text intact.
    ///
    /// The input is scanned with a small lexer that follows PostgreSQL's
    /// lexical rules. `$n` sequences are *not* rewritten when they appear in:
    ///
    /// - single-quoted strings (`'...'` with `''` escapes; backslash escapes
    ///   are honoured only in `E'...'` strings),
    /// - double-quoted identifiers (`"..."` with `""` escapes),
    /// - `--` line comments (ended by `\n` or `\r`),
    /// - `/* ... */` block comments, which nest as they do in PostgreSQL,
    /// - dollar-quoted strings (`$$...$$`, `$tag$...$tag$`),
    /// - unquoted identifiers, where `$` is an ordinary identifier character
    ///   (e.g. `col$1`).
    ///
    /// These are copied unchanged: `$0`; placeholders whose number does not
    /// fit in `usize`; and placeholders that would overflow once shifted.
    /// An `offset` of `0` returns a copy of `sql`.
    pub(crate) fn offset_postgres_placeholders(sql: &str, offset: usize) -> String {
        if offset == 0 {
            return sql.to_string();
        }

        #[derive(Clone, Copy)]
        enum ScanState {
            Normal,
            SingleQuoted { backslash_escapes: bool },
            DoubleQuoted,
            LineComment,
            // `depth` counts unmatched `/*` openers (always >= 1 in this state).
            BlockComment { depth: usize },
            DollarQuoted { tag_start: usize, tag_end: usize },
        }

        // Characters that may start an unquoted identifier. PostgreSQL treats
        // every non-ASCII character as an identifier letter.
        fn is_identifier_start(ch: char) -> bool {
            ch == '_' || ch.is_ascii_alphabetic() || !ch.is_ascii()
        }

        // Characters that may continue an unquoted identifier. Unlike the start
        // set, this includes digits and `$`.
        fn is_identifier_continue(ch: char) -> bool {
            is_identifier_start(ch) || ch.is_ascii_digit() || ch == '$'
        }

        // If a dollar-quote delimiter (`$$` or `$tag$`) starts at `start`,
        // returns the index of its closing `$`. Tag characters follow the
        // identifier rules minus `$`; the caller rejects tags starting with a digit.
        fn dollar_quote_tag_bounds(chars: &[char], start: usize) -> Option<usize> {
            if chars.get(start) != Some(&'$') {
                return None;
            }

            let mut index = start + 1;
            while index < chars.len() {
                match chars[index] {
                    '$' => return Some(index),
                    ch if is_identifier_start(ch) || ch.is_ascii_digit() => index += 1,
                    _ => return None,
                }
            }

            None
        }

        // Returns `true` when the quote at `quote_index` opens an `E'...'`
        // string: it must directly follow an `e`/`E` that is not itself the
        // tail of a longer identifier.
        fn has_escape_string_prefix(chars: &[char], quote_index: usize) -> bool {
            if quote_index == 0 || !matches!(chars[quote_index - 1], 'e' | 'E') {
                return false;
            }

            quote_index == 1 || !is_identifier_continue(chars[quote_index - 2])
        }

        let chars: Vec<char> = sql.chars().collect();
        let mut output = String::with_capacity(sql.len());
        let mut index = 0;
        let mut state = ScanState::Normal;
        // Tracks, in `Normal` state, whether the previous character belonged to
        // an unquoted identifier. Inside one, `$` is just another identifier
        // character and starts neither a placeholder nor a dollar quote.
        let mut in_identifier = false;

        while index < chars.len() {
            let ch = chars[index];
            match state {
                ScanState::Normal => {
                    let continues_identifier = in_identifier;
                    in_identifier = if continues_identifier {
                        is_identifier_continue(ch)
                    } else {
                        is_identifier_start(ch)
                    };

                    match ch {
                        '$' if continues_identifier => {
                            output.push(ch);
                            index += 1;
                        }
                        '\'' => {
                            output.push(ch);
                            state = ScanState::SingleQuoted {
                                backslash_escapes: has_escape_string_prefix(&chars, index),
                            };
                            index += 1;
                        }
                        '"' => {
                            output.push(ch);
                            state = ScanState::DoubleQuoted;
                            index += 1;
                        }
                        '-' if chars.get(index + 1) == Some(&'-') => {
                            output.push_str("--");
                            state = ScanState::LineComment;
                            index += 2;
                        }
                        '/' if chars.get(index + 1) == Some(&'*') => {
                            output.push_str("/*");
                            state = ScanState::BlockComment { depth: 1 };
                            index += 2;
                        }
                        '$' => {
                            // A dollar-quote tag may not start with a digit, so
                            // `$1$` is placeholder `$1` followed by a literal `$`.
                            let dollar_quote_end = dollar_quote_tag_bounds(&chars, index)
                                .filter(|&tag_end| {
                                    tag_end == index + 1 || !chars[index + 1].is_ascii_digit()
                                });

                            if let Some(tag_end) = dollar_quote_end {
                                output.extend(chars[index..=tag_end].iter());
                                state = ScanState::DollarQuoted {
                                    tag_start: index,
                                    tag_end,
                                };
                                index = tag_end + 1;
                            } else {
                                let digits_start = index + 1;
                                let mut digits_end = digits_start;
                                while digits_end < chars.len() && chars[digits_end].is_ascii_digit()
                                {
                                    digits_end += 1;
                                }

                                let renumbered = chars[digits_start..digits_end]
                                    .iter()
                                    .collect::<String>()
                                    .parse::<usize>()
                                    .ok()
                                    .filter(|&number| number > 0)
                                    .and_then(|number| number.checked_add(offset));

                                match renumbered {
                                    Some(number) => {
                                        output.push('$');
                                        output.push_str(&number.to_string());
                                        index = digits_end;
                                    }
                                    None => {
                                        // Not a rewritable placeholder: copy the `$`;
                                        // any digits are copied by later iterations.
                                        output.push(ch);
                                        index += 1;
                                    }
                                }
                            }
                        }
                        _ => {
                            output.push(ch);
                            index += 1;
                        }
                    }
                }
                ScanState::SingleQuoted { backslash_escapes } => {
                    output.push(ch);
                    index += 1;
                    if backslash_escapes && ch == '\\' {
                        // In `E'...'` strings a backslash escapes the next
                        // character, including a quote.
                        if let Some(&escaped) = chars.get(index) {
                            output.push(escaped);
                            index += 1;
                        }
                    } else if ch == '\'' {
                        if chars.get(index) == Some(&'\'') {
                            // `''` is an escaped quote, not the end of the string.
                            output.push('\'');
                            index += 1;
                        } else {
                            state = ScanState::Normal;
                        }
                    }
                }
                ScanState::DoubleQuoted => {
                    output.push(ch);
                    index += 1;
                    if ch == '"' {
                        if chars.get(index) == Some(&'"') {
                            // `""` is an escaped quote inside a quoted identifier.
                            output.push('"');
                            index += 1;
                        } else {
                            state = ScanState::Normal;
                        }
                    }
                }
                ScanState::LineComment => {
                    output.push(ch);
                    index += 1;
                    if ch == '\n' || ch == '\r' {
                        state = ScanState::Normal;
                    }
                }
                ScanState::BlockComment { depth } => {
                    let next = chars.get(index + 1).copied();
                    if ch == '*' && next == Some('/') {
                        output.push_str("*/");
                        index += 2;
                        state = if depth == 1 {
                            ScanState::Normal
                        } else {
                            ScanState::BlockComment { depth: depth - 1 }
                        };
                    } else if ch == '/' && next == Some('*') {
                        output.push_str("/*");
                        index += 2;
                        state = ScanState::BlockComment { depth: depth + 1 };
                    } else {
                        output.push(ch);
                        index += 1;
                    }
                }
                ScanState::DollarQuoted { tag_start, tag_end } => {
                    // The body ends only at an exact repeat of the opening tag.
                    let tag = &chars[tag_start..=tag_end];
                    if chars.get(index..index + tag.len()) == Some(tag) {
                        output.extend(tag.iter());
                        index += tag.len();
                        state = ScanState::Normal;
                    } else {
                        output.push(ch);
                        index += 1;
                    }
                }
            }
        }

        output
    }
}

/// Returns `true` when `name` is a non-empty ASCII identifier: an ASCII letter
/// or `_`, followed by any mix of ASCII letters, digits, and `_`.
fn is_plain_ascii_identifier(name: &str) -> bool {
    let mut chars = name.chars();
    matches!(chars.next(), Some(ch) if ch == '_' || ch.is_ascii_alphabetic())
        && chars.all(|ch| ch == '_' || ch.is_ascii_alphanumeric())
}