use crate::config::SourceType;
/// Quotes `name` as a delimited identifier for the given SQL dialect.
///
/// Each dialect wraps the name in its delimiter pair: `"…"` for Postgres,
/// `` `…` `` for MySQL and `[…]` for MSSQL. Every occurrence of the *closing*
/// delimiter inside `name` is doubled, so the result is always one identifier
/// token, whatever characters `name` contains.
pub(crate) fn quote_ident(source_type: SourceType, name: &str) -> String {
    let (open, close) = match source_type {
        SourceType::Postgres => ('"', '"'),
        SourceType::Mysql => ('`', '`'),
        SourceType::Mssql => ('[', ']'),
    };
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push(open);
    for ch in name.chars() {
        // Only the closing delimiter needs escaping (e.g. `[` is literal in MSSQL).
        if ch == close {
            quoted.push(close);
        }
        quoted.push(ch);
    }
    quoted.push(close);
    quoted
}
/// Recognises the trivial query `SELECT * FROM <table>` and returns `<table>`.
///
/// Keywords are matched case-insensitively, and surrounding whitespace is ignored.
/// `<table>` must be `name` or `schema.name`. Each part must be an unquoted
/// ASCII identifier (`[A-Za-z_][A-Za-z0-9_]*`). Anything else returns `None`:
/// a clause (`WHERE`, `JOIN`, …), a trailing `;`, a column list, quoted names
/// or three-part names. The returned slice borrows from `base_query`.
pub(crate) fn strip_select_star_from(base_query: &str) -> Option<&str> {
    /// True when `part` is a plain unquoted ASCII identifier.
    fn is_plain_ident(part: &str) -> bool {
        let mut chars = part.chars();
        matches!(chars.next(), Some(c) if c.is_ascii_alphabetic() || c == '_')
            && chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
    }

    let after_from = strip_prefix_ascii_ci(base_query.trim(), "select")
        .map(str::trim_start)
        .and_then(|s| s.strip_prefix('*'))
        .map(str::trim_start)
        .and_then(|s| strip_prefix_ascii_ci(s, "from"))?;
    let ident = after_from.trim_start();
    // `FROM` must end at a word boundary: `SELECT * FROMevents` is not a table
    // scan of `events`. (Quoted / parenthesised sources are rejected anyway.)
    if ident.len() == after_from.len() {
        return None;
    }
    // `ident` has no trailing whitespace, because the whole query was trimmed.
    // Any trailing clause or `;` therefore ends up inside a part and fails
    // `is_plain_ident`.
    (ident.split('.').count() <= 2 && ident.split('.').all(is_plain_ident)).then_some(ident)
}

/// Strips an ASCII `prefix` from the start of `s`, ignoring ASCII case.
///
/// Returns `None` when `s` does not start with `prefix`. This includes the case
/// where `prefix.len()` would split a multi-byte UTF-8 character of `s`, so
/// non-ASCII input can never cause a panic.
fn strip_prefix_ascii_ci<'a>(s: &'a str, prefix: &str) -> Option<&'a str> {
    // `str::get` checks the char boundary; a direct `s[..n]` would panic.
    let head = s.get(..prefix.len())?;
    if head.eq_ignore_ascii_case(prefix) {
        Some(&s[prefix.len()..])
    } else {
        None
    }
}
/// Builds a single-row query that computes `agg(col)` over the rows of `base_query`.
///
/// If `base_query` is just `SELECT * FROM <table>` (see
/// [`strip_select_star_from`]), the aggregate runs directly against the table.
/// Otherwise the query is wrapped as a derived table aliased `_rivet`.
///
/// `col` is quoted for the dialect via [`quote_ident`]. `agg` and `base_query`
/// are interpolated verbatim (not escaped), so they must come from trusted
/// configuration.
pub(crate) fn aggregate_sql(
    source_type: SourceType,
    agg: &str,
    col: &str,
    base_query: &str,
) -> String {
    let q = quote_ident(source_type, col);
    if let Some(table_ident) = strip_select_star_from(base_query) {
        return format!("SELECT {agg}({q}) FROM {table_ident}");
    }
    // A statement terminator is a syntax error inside a derived table, so drop
    // any trailing `;` (and trailing whitespace) before wrapping.
    let inner = base_query.trim_end_matches(|c: char| c == ';' || c.is_whitespace());
    // A `--` comment runs to end of line. Without a line break it would swallow
    // the closing `) AS _rivet`; the extra newline is harmless otherwise.
    let close = if inner.contains("--") { "\n" } else { "" };
    format!("SELECT {agg}({q}) FROM ({inner}{close}) AS _rivet")
}
/// Builds a catalog query that estimates the row count of `table_ident` without scanning it.
///
/// - Postgres: reads `pg_class.reltuples` for the table resolved via
///   `::regclass`, clamped at 0 with `GREATEST` and cast to `bigint`.
/// - MSSQL: sums `sys.dm_db_partition_stats.row_count` for `index_id` 0/1
///   (heap or clustered index), so nonclustered indexes are not counted twice.
/// - MySQL: no estimate is offered; returns `None`.
///
/// `table_ident` is embedded inside a SQL string literal, so any `'` is
/// doubled. This is a no-op for plain `schema.table` names.
pub(crate) fn row_estimate_sql(source_type: SourceType, table_ident: &str) -> Option<String> {
    let lit = table_ident.replace('\'', "''");
    match source_type {
        SourceType::Postgres => Some(format!(
            "SELECT GREATEST(reltuples, 0)::bigint FROM pg_class WHERE oid = '{lit}'::regclass"
        )),
        SourceType::Mysql => None,
        SourceType::Mssql => Some(format!(
            "SELECT SUM(p.row_count) FROM sys.dm_db_partition_stats p \
             WHERE p.object_id = OBJECT_ID('{lit}') AND p.index_id IN (0,1)"
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn postgres_plain_identifier() {
        assert_eq!(quote_ident(SourceType::Postgres, "id"), "\"id\"");
        assert_eq!(
            quote_ident(SourceType::Postgres, "created_at"),
            "\"created_at\""
        );
    }

    #[test]
    fn postgres_escapes_internal_double_quotes() {
        assert_eq!(
            quote_ident(SourceType::Postgres, "col\"name"),
            "\"col\"\"name\""
        );
    }

    #[test]
    fn mysql_plain_identifier() {
        assert_eq!(quote_ident(SourceType::Mysql, "id"), "`id`");
        assert_eq!(quote_ident(SourceType::Mysql, "created_at"), "`created_at`");
    }

    #[test]
    fn mysql_escapes_internal_backticks() {
        assert_eq!(quote_ident(SourceType::Mysql, "col`name"), "`col``name`");
    }

    #[test]
    fn mssql_plain_identifier() {
        assert_eq!(quote_ident(SourceType::Mssql, "id"), "[id]");
        assert_eq!(quote_ident(SourceType::Mssql, "created_at"), "[created_at]");
    }

    #[test]
    fn mssql_escapes_only_closing_brackets() {
        assert_eq!(quote_ident(SourceType::Mssql, "col]name"), "[col]]name]");
        // `[` inside a bracketed identifier is literal and must not be doubled.
        assert_eq!(quote_ident(SourceType::Mssql, "col[name"), "[col[name]");
    }

    #[test]
    fn strip_select_star_from_accepts_simple_table_forms() {
        assert_eq!(
            strip_select_star_from("SELECT * FROM events"),
            Some("events")
        );
        assert_eq!(
            strip_select_star_from("select * from public.orders"),
            Some("public.orders")
        );
        assert_eq!(
            strip_select_star_from("  SeLeCt *\nFROM events  "),
            Some("events")
        );
    }

    #[test]
    fn strip_select_star_from_rejects_anything_with_a_clause() {
        assert!(strip_select_star_from("SELECT * FROM events WHERE id > 1").is_none());
        assert!(strip_select_star_from("SELECT id FROM events").is_none());
        assert!(strip_select_star_from("SELECT * FROM a JOIN b").is_none());
        assert!(strip_select_star_from("SELECT * FROM events;").is_none());
        assert!(strip_select_star_from("SELECT * FROM a.b.c").is_none());
    }

    #[test]
    fn aggregate_sql_fast_path_on_table_shortcut() {
        assert_eq!(
            aggregate_sql(
                SourceType::Postgres,
                "min",
                "created_at",
                "SELECT * FROM events"
            ),
            "SELECT min(\"created_at\") FROM events"
        );
        assert_eq!(
            aggregate_sql(SourceType::Mssql, "min", "d", "SELECT * FROM dbo.t"),
            "SELECT min([d]) FROM dbo.t"
        );
    }

    #[test]
    fn aggregate_sql_wraps_a_real_query() {
        assert_eq!(
            aggregate_sql(
                SourceType::Postgres,
                "max",
                "created_at",
                "SELECT id, created_at FROM events WHERE x"
            ),
            "SELECT max(\"created_at\") FROM (SELECT id, created_at FROM events WHERE x) AS _rivet"
        );
        assert!(
            aggregate_sql(SourceType::Mysql, "min", "d", "SELECT d FROM t WHERE 1")
                .contains("min(`d`)")
        );
    }

    #[test]
    fn row_estimate_sql_is_scan_free_or_skipped_per_dialect() {
        let pg = row_estimate_sql(SourceType::Postgres, "warranty").expect("PG has an estimate");
        assert!(pg.contains("reltuples") && pg.contains("pg_class"), "{pg}");
        assert!(!pg.contains("COUNT"), "estimate must not scan: {pg}");
        let ms =
            row_estimate_sql(SourceType::Mssql, "dbo.warranty").expect("MSSQL has an estimate");
        assert!(
            ms.contains("dm_db_partition_stats") && ms.contains("OBJECT_ID('dbo.warranty')"),
            "{ms}"
        );
        assert!(!ms.contains("COUNT"), "estimate must not scan: {ms}");
        assert!(row_estimate_sql(SourceType::Mysql, "warranty").is_none());
        assert!(row_estimate_sql(SourceType::Mysql, "shop.warranty").is_none());
    }
}