use crate::config::SourceType;
/// Quotes `name` as an identifier for the given SQL dialect.
///
/// The identifier is wrapped in the dialect's delimiters (`"…"` for Postgres,
/// `` `…` `` for MySQL, `[…]` for MSSQL) and every embedded closing delimiter
/// is doubled so it cannot terminate the quoted identifier early.
pub(crate) fn quote_ident(source_type: SourceType, name: &str) -> String {
    let (open, close) = match source_type {
        SourceType::Postgres => ('"', '"'),
        SourceType::Mysql => ('`', '`'),
        SourceType::Mssql => ('[', ']'),
    };

    // Two delimiters plus the name itself; grows only if escaping is needed.
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push(open);
    for ch in name.chars() {
        if ch == close {
            // Escape by doubling the closing delimiter.
            quoted.push(close);
        }
        quoted.push(ch);
    }
    quoted.push(close);
    quoted
}
/// Recognises queries of the exact shape `SELECT * FROM <table>` and returns
/// the table identifier.
///
/// Keywords are matched ASCII case-insensitively and surrounding whitespace is
/// ignored. The identifier must be a bare `name` or `schema.name`; anything
/// after it (a clause, a join, a trailing `;`) yields `None`.
///
/// `FROM` must be followed by whitespace so that input such as
/// `SELECT * FROMevents` is not mistaken for a query over table `events`.
pub(crate) fn strip_select_star_from(base_query: &str) -> Option<&str> {
    let after_select = strip_prefix_ascii_ci(base_query.trim(), "select")?;
    let after_star = after_select.trim_start().strip_prefix('*')?;
    let after_from = strip_prefix_ascii_ci(after_star.trim_start(), "from")?;

    // Keyword boundary: `FROM` glued to the next token is not the FROM keyword.
    if !after_from.starts_with(char::is_whitespace) {
        return None;
    }
    parse_bare_table_ident(after_from.trim_start())
}
/// Recognises queries of the shape `SELECT <plain column list> FROM <table>`
/// and returns the table identifier.
///
/// Returns `None` unless the projection passes `is_plain_column_list` and the
/// text after `FROM` is exactly one bare `name` or `schema.name` identifier
/// (as judged by `parse_bare_table_ident`).
///
/// `SELECT` must be followed by whitespace or `*`, so that input such as
/// `SELECTid FROM t` is not treated as a projection of `id` over `t`.
pub(crate) fn strip_simple_projection_from(base_query: &str) -> Option<&str> {
    let after_keyword = strip_prefix_ascii_ci(base_query.trim(), "select")?;

    // Keyword boundary: reject identifiers that merely start with "select".
    if !after_keyword.starts_with(|c: char| c.is_whitespace() || c == '*') {
        return None;
    }

    let after_select = after_keyword.trim_start();
    let from_at = find_from_keyword(after_select)?;
    let (projection, from_clause) = after_select.split_at(from_at);
    if !is_plain_column_list(projection.trim()) {
        return None;
    }
    // `from_clause` starts with the matched keyword; skip past it.
    parse_bare_table_ident(from_clause["from".len()..].trim_start())
}
/// Returns the byte offset of the first standalone `FROM` keyword in `s`.
///
/// The match is ASCII case-insensitive. The keyword must be preceded by ASCII
/// whitespace or the start of `s`, and followed by ASCII whitespace or the end
/// of `s`.
///
/// The scan works on bytes rather than `str` slices, so input containing
/// multi-byte UTF-8 characters cannot trigger a char-boundary panic. A
/// returned offset always sits on a char boundary because the four matched
/// bytes are ASCII.
fn find_from_keyword(s: &str) -> Option<usize> {
    const KEYWORD: &[u8] = b"from";

    let bytes = s.as_bytes();
    bytes
        .windows(KEYWORD.len())
        .enumerate()
        .find(|&(start, window)| {
            let end = start + KEYWORD.len();
            window.eq_ignore_ascii_case(KEYWORD)
                && (start == 0 || bytes[start - 1].is_ascii_whitespace())
                && (end == bytes.len() || bytes[end].is_ascii_whitespace())
        })
        .map(|(start, _)| start)
}
/// Reports whether `s` is a comma-separated list of bare column references.
///
/// Each item must be `*` or a dotted path of identifiers (`col`, `t.col`),
/// optionally ending in `.*` (`t.*`). Items may be padded with spaces, tabs or
/// line breaks, but must not contain whitespace internally.
///
/// Rejecting internal whitespace keeps row-set modifiers (`DISTINCT id`,
/// `TOP 10 id`) and aliases (`id AS total`, `id total`) out of the "plain"
/// category: a caller that replaces the query with its bare table would
/// otherwise change which rows or which column names are visible.
fn is_plain_column_list(s: &str) -> bool {
    /// `[A-Za-z_][A-Za-z0-9_]*`
    fn is_identifier(segment: &str) -> bool {
        let mut chars = segment.chars();
        match chars.next() {
            Some(first) if first.is_ascii_alphabetic() || first == '_' => {
                chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
            }
            _ => false,
        }
    }

    /// A dotted identifier path whose final segment may be `*`.
    fn is_column_ref(item: &str) -> bool {
        let mut segments = item.split('.').peekable();
        while let Some(segment) = segments.next() {
            let is_last = segments.peek().is_none();
            if !(is_identifier(segment) || (is_last && segment == "*")) {
                return false;
            }
        }
        true
    }

    // An empty `s` produces a single empty item, which is rejected above.
    s.split(',').all(|item| {
        is_column_ref(item.trim_matches(|c: char| matches!(c, ' ' | '\t' | '\n' | '\r')))
    })
}
/// Parses the text following `FROM` as a single bare table identifier.
///
/// Returns the identifier (`name` or `schema.name`) only when every dotted
/// part is `[A-Za-z_][A-Za-z0-9_]*`, there are at most two parts, and nothing
/// but whitespace follows it. Everything else yields `None`.
fn parse_bare_table_ident(after_from: &str) -> Option<&str> {
    fn is_word_char(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '_'
    }

    fn is_simple_ident(part: &str) -> bool {
        let mut remaining = part.chars();
        match remaining.next() {
            Some(first) => (first.is_ascii_alphabetic() || first == '_') && remaining.all(is_word_char),
            None => false,
        }
    }

    let candidate = after_from.trim_start();

    // Split at the first character that cannot belong to a dotted identifier.
    let (ident, tail) = match candidate.find(|c: char| !(is_word_char(c) || c == '.')) {
        Some(cut) => candidate.split_at(cut),
        None => (candidate, ""),
    };

    let well_formed = ident.split('.').count() <= 2 && ident.split('.').all(is_simple_ident);
    if well_formed && tail.trim().is_empty() {
        Some(ident)
    } else {
        None
    }
}
/// Strips `prefix` from the start of `s`, comparing ASCII case-insensitively.
///
/// Returns the remainder of `s` after the prefix, or `None` when `s` does not
/// start with it.
///
/// The head of `s` is taken with checked slicing: if `prefix.len()` falls
/// inside a multi-byte character of `s`, the prefix cannot match and `None`
/// is returned instead of panicking on a char-boundary violation.
fn strip_prefix_ascii_ci<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
    let head = s.get(..prefix.len())?;
    if head.eq_ignore_ascii_case(prefix) {
        // `get` succeeded above, so this offset is in range and on a boundary.
        Some(&s[prefix.len()..])
    } else {
        None
    }
}
/// Builds a `SELECT <agg>(<col>) FROM …` statement over `base_query`.
///
/// `col` is quoted for `source_type` via `quote_ident`; `agg` is inserted
/// verbatim. When `base_query` is a plain projection over a single table
/// (see `strip_simple_projection_from`) the aggregate reads from that table
/// directly; otherwise `base_query` is wrapped as a derived table aliased
/// `_rivet`.
pub(crate) fn aggregate_sql(
    source_type: SourceType,
    agg: &str,
    col: &str,
    base_query: &str,
) -> String {
    let q = quote_ident(source_type, col);

    // Fast path: aggregate straight over the underlying table.
    if let Some(table_ident) = strip_simple_projection_from(base_query) {
        return format!("SELECT {agg}({q}) FROM {table_ident}");
    }

    // General path: treat the whole query as a subquery.
    format!("SELECT {agg}({q}) FROM ({base_query}) AS _rivet")
}
/// Builds a statement that returns at most one row when `col` contains a NULL
/// anywhere in the rows of `base_query`.
///
/// The FROM target is the bare table when `base_query` is a plain projection
/// over one table (see `strip_simple_projection_from`), otherwise the query
/// wrapped as a derived table aliased `_rivet_nullprobe`. The single-row limit
/// is spelled `TOP 1` for MSSQL and `LIMIT 1` for Postgres and MySQL.
pub(crate) fn null_key_probe_sql(source_type: SourceType, col: &str, base_query: &str) -> String {
    let q = quote_ident(source_type, col);
    let from = strip_simple_projection_from(base_query).map_or_else(
        || format!("({base_query}) AS _rivet_nullprobe"),
        str::to_owned,
    );

    match source_type {
        SourceType::Postgres | SourceType::Mysql => {
            format!("SELECT 1 FROM {from} WHERE {q} IS NULL LIMIT 1")
        }
        SourceType::Mssql => format!("SELECT TOP 1 1 FROM {from} WHERE {q} IS NULL"),
    }
}
/// Builds a catalog query that reads a row-count figure for `table_ident`
/// without querying the table itself.
///
/// * Postgres: reads `reltuples` from `pg_class`, clamped at zero and cast to
///   `bigint`, resolving the table through a `::regclass` cast.
/// * MSSQL: sums `row_count` from `sys.dm_db_partition_stats` for index ids
///   0 and 1 of `OBJECT_ID(<table>)`.
/// * MySQL: returns `None`; no statement is produced for this dialect.
///
/// `table_ident` is embedded in a single-quoted SQL string literal, so any
/// `'` it contains is doubled to keep the literal well-formed. Names without
/// a single quote produce exactly the same SQL as before.
pub(crate) fn row_estimate_sql(source_type: SourceType, table_ident: &str) -> Option<String> {
    // Standard SQL escaping for a quote inside a string literal.
    let literal = table_ident.replace('\'', "''");
    match source_type {
        SourceType::Postgres => Some(format!(
            "SELECT GREATEST(reltuples, 0)::bigint FROM pg_class WHERE oid = '{literal}'::regclass"
        )),
        SourceType::Mysql => None,
        SourceType::Mssql => Some(format!(
            "SELECT SUM(p.row_count) FROM sys.dm_db_partition_stats p \
             WHERE p.object_id = OBJECT_ID('{literal}') AND p.index_id IN (0,1)"
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain Postgres identifiers are wrapped in double quotes.
    #[test]
    fn postgres_plain_identifier() {
        for (raw, quoted) in [("id", "\"id\""), ("created_at", "\"created_at\"")] {
            assert_eq!(quote_ident(SourceType::Postgres, raw), quoted);
        }
    }

    /// An embedded double quote is doubled.
    #[test]
    fn postgres_escapes_internal_double_quotes() {
        let quoted = quote_ident(SourceType::Postgres, "col\"name");
        assert_eq!(quoted, "\"col\"\"name\"");
    }

    /// Plain MySQL identifiers are wrapped in backticks.
    #[test]
    fn mysql_plain_identifier() {
        for (raw, quoted) in [("id", "`id`"), ("created_at", "`created_at`")] {
            assert_eq!(quote_ident(SourceType::Mysql, raw), quoted);
        }
    }

    /// An embedded backtick is doubled.
    #[test]
    fn mysql_escapes_internal_backticks() {
        let quoted = quote_ident(SourceType::Mysql, "col`name");
        assert_eq!(quoted, "`col``name`");
    }

    /// `SELECT * FROM <table>` yields the table, with or without a schema.
    #[test]
    fn strip_select_star_from_accepts_simple_table_forms() {
        let accepted = [
            ("SELECT * FROM events", "events"),
            ("select * from public.orders", "public.orders"),
        ];
        for (query, table) in accepted {
            assert_eq!(strip_select_star_from(query), Some(table), "query: {query}");
        }
    }

    /// Anything beyond the bare `SELECT * FROM <table>` shape is refused.
    #[test]
    fn strip_select_star_from_rejects_anything_with_a_clause() {
        let rejected = [
            "SELECT * FROM events WHERE id > 1",
            "SELECT id FROM events",
            "SELECT * FROM a JOIN b",
            "SELECT * FROM events;",
            "SELECT * FROM a.b.c",
        ];
        for query in rejected {
            assert_eq!(strip_select_star_from(query), None, "query: {query}");
        }
    }

    /// Plain column lists over a single table yield that table.
    #[test]
    fn strip_simple_projection_accepts_plain_column_lists() {
        let accepted = [
            ("SELECT id, title, body FROM content_items", "content_items"),
            ("SELECT * FROM events", "events"),
            ("select a.x, a.y from public.users", "public.users"),
            ("SELECT id, a, b, c, d FROM content_items\n", "content_items"),
        ];
        for (query, table) in accepted {
            assert_eq!(
                strip_simple_projection_from(query),
                Some(table),
                "query: {query}"
            );
        }
    }

    /// Queries whose result differs from the bare table are refused.
    #[test]
    fn strip_simple_projection_rejects_anything_that_changes_the_row_set() {
        let rejected = [
            "SELECT id FROM t WHERE id > 1",
            "SELECT a.id FROM t a JOIN u ON a.id = u.id",
            "SELECT id FROM t GROUP BY id",
            "SELECT id FROM t;",
            "SELECT DISTINCT id FROM t",
            "SELECT count(*) FROM t",
            "SELECT lower(name) FROM t",
            "SELECT id FROM (SELECT id FROM y) z",
            "SELECT 'from x' FROM t",
            "SELECT id FROM a.b.c",
        ];
        for query in rejected {
            assert_eq!(strip_simple_projection_from(query), None, "query: {query}");
        }
    }

    /// A bare-table query is aggregated directly, without a subquery.
    #[test]
    fn aggregate_sql_fast_path_on_table_shortcut() {
        let sql = aggregate_sql(
            SourceType::Postgres,
            "min",
            "created_at",
            "SELECT * FROM events",
        );
        assert_eq!(sql, "SELECT min(\"created_at\") FROM events");
    }

    /// A query with a clause is wrapped as a derived table.
    #[test]
    fn aggregate_sql_wraps_a_real_query() {
        let wrapped = aggregate_sql(
            SourceType::Postgres,
            "max",
            "created_at",
            "SELECT id, created_at FROM events WHERE x",
        );
        assert_eq!(
            wrapped,
            "SELECT max(\"created_at\") FROM (SELECT id, created_at FROM events WHERE x) AS _rivet"
        );

        let mysql = aggregate_sql(SourceType::Mysql, "min", "d", "SELECT d FROM t WHERE 1");
        assert!(mysql.contains("min(`d`)"));
    }

    /// The NULL probe asks for one row at most and never counts.
    #[test]
    fn null_key_probe_sql_is_a_presence_probe_not_a_count() {
        let pg_direct =
            null_key_probe_sql(SourceType::Postgres, "id", "SELECT id, title FROM orders");
        assert_eq!(pg_direct, "SELECT 1 FROM orders WHERE \"id\" IS NULL LIMIT 1");

        let mssql_direct = null_key_probe_sql(SourceType::Mssql, "id", "SELECT * FROM orders");
        assert_eq!(mssql_direct, "SELECT TOP 1 1 FROM orders WHERE [id] IS NULL");

        let pg_wrapped = null_key_probe_sql(
            SourceType::Postgres,
            "id",
            "SELECT id FROM orders WHERE x > 1",
        );
        assert_eq!(
            pg_wrapped,
            "SELECT 1 FROM (SELECT id FROM orders WHERE x > 1) AS _rivet_nullprobe \
             WHERE \"id\" IS NULL LIMIT 1"
        );

        let mysql_direct = null_key_probe_sql(SourceType::Mysql, "k", "SELECT * FROM t");
        assert!(!mysql_direct.contains("COUNT"));
    }

    /// Row estimates come from catalog views, or are skipped for MySQL.
    #[test]
    fn row_estimate_sql_is_scan_free_or_skipped_per_dialect() {
        let pg = row_estimate_sql(SourceType::Postgres, "warranty").expect("PG has an estimate");
        assert!(pg.contains("reltuples") && pg.contains("pg_class"), "{pg}");
        assert!(!pg.contains("COUNT"), "estimate must not scan: {pg}");

        let ms =
            row_estimate_sql(SourceType::Mssql, "dbo.warranty").expect("MSSQL has an estimate");
        assert!(
            ms.contains("dm_db_partition_stats") && ms.contains("OBJECT_ID('dbo.warranty')"),
            "{ms}"
        );
        assert!(!ms.contains("COUNT"), "estimate must not scan: {ms}");

        for table in ["warranty", "shop.warranty"] {
            assert!(row_estimate_sql(SourceType::Mysql, table).is_none());
        }
    }
}