es-entity-macros 0.5.2

Proc macros for es-entity
Documentation
use proc_macro2::Span;
use syn::{
    parse::{Parse, ParseStream},
    punctuated::Punctuated,
};

/// Parsed input of a query macro, built from comma-separated `key = value` pairs
/// (see the `Parse` impl below).
pub struct QueryInput {
    /// Optional prefix given via `ignore_prefix = "..."`; removed from the table name by
    /// `table_name_without_prefix`.
    pub(super) ignore_prefix: Option<String>,
    /// Expression given via `executor = ...` (presumably the database executor / connection
    /// the generated query runs on — confirm at the call site).
    pub(super) executor: syn::Expr,
    /// The SQL text: one or more string literals joined with `+`, concatenated as-is.
    pub(super) sql: String,
    /// Span attached to errors reported about the SQL (e.g. by `table_name`).
    pub(super) sql_span: Span,
    /// Expressions from `args = [...]`; empty when `args` is omitted.
    pub(super) arg_exprs: Vec<syn::Expr>,
}

impl QueryInput {
    pub(super) fn table_name(&self) -> darling::Result<String> {
        let query = self.sql.to_lowercase();
        let words: Vec<&str> = query.split_whitespace().collect();
        let from_pos = words.iter().position(|&word| word == "from").ok_or(
            darling::Error::custom("Could not identify table name - no 'FROM' clause")
                .with_span(&self.sql_span),
        )?;
        let table_name = words.get(from_pos + 1).ok_or(
            darling::Error::custom("No word after 'FROM' clause").with_span(&self.sql_span),
        )?;
        let table_name = table_name.trim_end_matches(|c: char| !c.is_alphanumeric());
        Ok(table_name.to_string())
    }

    pub(super) fn table_name_without_prefix(&self) -> darling::Result<String> {
        let table_name = self.table_name()?;
        if let Some(ignore_prefix) = &self.ignore_prefix {
            if table_name.starts_with(ignore_prefix) {
                return Ok(table_name[ignore_prefix.len() + 1..].to_string());
            }
        }
        Ok(table_name)
    }

    pub(super) fn order_by(&self) -> String {
        let columns = self.order_by_columns();
        if columns.is_empty() {
            "i.id,".to_string()
        } else {
            columns.join(", ") + ", i.id,"
        }
    }

    fn order_by_columns(&self) -> Vec<String> {
        use regex::Regex;
        let re = Regex::new(r"(?i)ORDER\s+BY\s+(.+?)(?:\s+(?:LIMIT|OFFSET)|\s*;?\s*$)").unwrap();

        if let Some(captures) = re.captures(&self.sql.to_lowercase()) {
            if let Some(order_by_clause) = captures.get(1) {
                return order_by_clause
                    .as_str()
                    .split(',')
                    .map(|s| format!("i.{}", s.trim()))
                    .filter(|s| !s.is_empty())
                    .collect();
            }
        }

        Vec::new()
    }
}

impl Parse for QueryInput {
    /// Parses comma-separated `key = value` pairs in any order:
    ///
    /// - `sql = "..." [+ "..."]*` (required): one or more string literals joined with `+`,
    ///   concatenated without any separator.
    /// - `executor = <expr>` (required).
    /// - `ignore_prefix = "..."` (optional).
    /// - `args = [<expr>, ...]` (optional; defaults to no arguments).
    ///
    /// A trailing comma after the last pair is accepted. A key given more than once keeps its
    /// last value.
    ///
    /// # Errors
    ///
    /// Fails on an unknown key, a malformed value, a missing comma between pairs, or a
    /// missing `sql` / `executor` key.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let mut sql: Option<(String, Span)> = None;
        let mut args: Option<Vec<syn::Expr>> = None;
        let mut executor: Option<syn::Expr> = None;
        let mut ignore_prefix = None;

        while !input.is_empty() {
            let key: syn::Ident = input.parse()?;
            input.parse::<syn::Token![=]>()?;

            if key == "executor" {
                executor = Some(input.parse::<syn::Expr>()?);
            } else if key == "ignore_prefix" {
                ignore_prefix = Some(input.parse::<syn::LitStr>()?.value());
            } else if key == "sql" {
                // Capture the span *before* consuming the literals, so errors about the SQL
                // point at the SQL itself rather than at whatever token follows it.
                let span = input.span();
                let literals =
                    Punctuated::<syn::LitStr, syn::Token![+]>::parse_separated_nonempty(input)?;
                let text: String = literals.iter().map(syn::LitStr::value).collect();
                sql = Some((text, span));
            } else if key == "args" {
                let exprs = input.parse::<syn::ExprArray>()?;
                args = Some(exprs.elems.into_iter().collect());
            } else {
                let message = format!("unexpected input key: {key}");
                return Err(syn::Error::new_spanned(key, message));
            }

            // Pairs are comma-separated; a trailing comma right before the end is allowed.
            if input.is_empty() {
                break;
            }
            input.parse::<syn::Token![,]>()?;
        }

        let (sql, sql_span) = sql.ok_or_else(|| input.error("expected `sql` key"))?;
        let executor = executor.ok_or_else(|| input.error("expected `executor` key"))?;

        Ok(QueryInput {
            ignore_prefix,
            executor,
            sql,
            sql_span,
            arg_exprs: args.unwrap_or_default(),
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use syn::parse_quote;

    /// Builds a `QueryInput` around `sql`, filling every other field with a placeholder.
    fn input_for(sql: &str) -> QueryInput {
        QueryInput {
            ignore_prefix: None,
            executor: parse_quote!(&mut **tx),
            sql: sql.to_string(),
            sql_span: Span::call_site(),
            arg_exprs: vec![],
        }
    }

    #[test]
    fn parse_input() {
        let parsed: QueryInput = parse_quote!(
            ignore_prefix = "ignore_prefix",
            executor = &mut **tx,
            sql = "SELECT * FROM ignore_prefix_users WHERE name = $1",
            args = [id]
        );
        let expected_executor: syn::Expr = parse_quote!(&mut **tx);
        let expected_first_arg: syn::Expr = parse_quote!(id);

        assert_eq!(parsed.ignore_prefix.as_deref(), Some("ignore_prefix"));
        assert_eq!(
            parsed.sql.as_str(),
            "SELECT * FROM ignore_prefix_users WHERE name = $1"
        );
        assert_eq!(parsed.executor, expected_executor);
        assert_eq!(parsed.arg_exprs[0], expected_first_arg);
        assert_eq!(parsed.table_name_without_prefix().unwrap(), "users");
    }

    #[test]
    fn test_order_by_columns() {
        // (sql, expected columns)
        let cases: &[(&str, &[&str])] = &[
            (
                "SELECT id FROM entities WHERE (id > $2) OR $2 IS NULL ORDER BY id LIMIT $1",
                &["i.id"],
            ),
            (
                "select id from entities order by name asc, date desc",
                &["i.name asc", "i.date desc"],
            ),
            ("SELECT TOP 10 id FROM entities Order By id", &["i.id"]),
            ("select id from entities ORDER BY id offset 10", &["i.id"]),
            ("SELECT id FROM entities orDer bY id;", &["i.id"]),
            (
                "SELECT * FROM users WHERE age > 18 ORDER BY last_name, first_name DESC LIMIT 10",
                &["i.last_name", "i.first_name desc"],
            ),
            (
                "SELECT * FROM products ORDER BY price ASC, stock DESC, name",
                &["i.price asc", "i.stock desc", "i.name"],
            ),
            ("SELECT * FROM orders", &[]),
            (
                "SELECT * FROM orders ORDER BY orders NULLS FIRST, id",
                &["i.orders nulls first", "i.id"],
            ),
        ];

        for &(sql, expected) in cases {
            let columns = input_for(sql).order_by_columns();
            assert_eq!(columns, expected, "Failed for SQL: {}", sql);
        }
    }
}