use serde::Deserialize;
use std::collections::{BTreeMap, BTreeSet, HashMap};
use std::env;
use std::error::Error;
use std::fs;
use std::io;
use std::path::Path;
/// Dialects the spec files are expected to cover. Used only by
/// `validate_dialect_coverage` to emit build warnings when a dialect is
/// missing from dialects.json or scoping_rules.toml; keep in sync with the
/// names handled by `dialect_to_variant`.
const KNOWN_DIALECTS: &[&str] = &[
    "bigquery",
    "clickhouse",
    "databricks",
    "duckdb",
    "hive",
    "mssql",
    "mysql",
    "postgres",
    "redshift",
    "snowflake",
    "sqlite",
];
/// Build-script entry point. All real work lives in `run` so it can use `?`;
/// any failure is reported on stderr and aborts the build with exit code 1.
fn main() {
    match run() {
        Ok(()) => {}
        Err(e) => {
            eprintln!("Build script error: {e}");
            std::process::exit(1);
        }
    }
}
/// Loads the spec files and regenerates everything under `src/generated/`.
///
/// Skips regeneration entirely when building from a packaged crate (the
/// generated files are shipped with the package), but still registers the
/// rerun triggers so cargo's change detection stays consistent.
fn run() -> Result<(), Box<dyn Error>> {
    // Detect a `cargo publish` / `target/package` build: specs may not be
    // present there, so only emit rerun directives and bail out early.
    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap_or_default();
    let in_packaged_dir = manifest_dir.contains("/target/package/");
    if env::var("CARGO_PUBLISH").is_ok() || in_packaged_dir {
        println!("cargo:rerun-if-changed=specs/dialect-semantics/");
        println!("cargo:rerun-if-changed=build.rs");
        return Ok(());
    }
    // Relative path: build scripts run with the crate manifest dir as cwd.
    let spec_dir = Path::new("specs/dialect-semantics");
    if !spec_dir.exists() {
        return Err(format!(
            "Spec directory not found at {:?}. Expected at crates/flowscope-core/specs/dialect-semantics/",
            spec_dir.canonicalize().unwrap_or_else(|_| spec_dir.to_path_buf())
        ).into());
    }
    let generated_dir = Path::new("src/generated");
    fs::create_dir_all(generated_dir)
        .map_err(|e| format!("Failed to create src/generated directory: {e}"))?;
    // Load all five spec inputs up front so any parse error fails the build
    // before any generated file is touched.
    let dialects = load_dialects_json(spec_dir)?;
    let functions = load_functions_json(spec_dir)?;
    let normalization_overrides = load_normalization_overrides(spec_dir)?;
    let scoping_rules = load_scoping_rules(spec_dir)?;
    let dialect_behavior = load_dialect_behavior(spec_dir)?;
    // Coverage gaps are warnings, not errors: they do not stop generation.
    validate_dialect_coverage(&dialects, &scoping_rules);
    generate_mod_rs(generated_dir)?;
    generate_case_sensitivity(generated_dir, &dialects, &normalization_overrides)?;
    generate_scoping_rules(generated_dir, &scoping_rules)?;
    generate_function_rules(generated_dir, &dialect_behavior)?;
    generate_functions(generated_dir, &functions)?;
    // Rebuild whenever the specs or this script change.
    println!("cargo:rerun-if-changed=specs/dialect-semantics/");
    println!("cargo:rerun-if-changed=build.rs");
    Ok(())
}
/// One entry of dialects.json, keyed by dialect name.
///
/// Only `normalization`, `pseudocolumns`, and `quote_chars` are consumed by
/// the generators; the remaining fields are deserialized for spec fidelity
/// (hence the `allow(dead_code)`).
#[derive(Debug, Deserialize)]
#[allow(dead_code)] struct DialectSpec {
    /// Strategy key matched in `generate_case_sensitivity`:
    /// "lowercase" | "uppercase" | "case_insensitive" | "case_sensitive".
    normalization: String,
    /// Implicit column names (e.g. _PARTITIONTIME) emitted as `pseudocolumns()`.
    #[serde(default)]
    pseudocolumns: Vec<String>,
    /// Identifier-quote / string-escape characters; see `QuoteChars`.
    #[serde(default)]
    quote_chars: Option<QuoteChars>,
    /// Not read by any generator in this script.
    #[serde(default)]
    parser_settings: Option<ParserSettings>,
    /// Not read by any generator in this script.
    #[serde(default)]
    generator_settings: Option<GeneratorSettings>,
    /// Not read by any generator in this script.
    #[serde(default)]
    type_mapping_count: Option<usize>,
}
/// Quote configuration from dialects.json.
#[derive(Debug, Deserialize, Default)]
#[allow(dead_code)] struct QuoteChars {
    /// Each entry is either a JSON string (a single quote char) or an array
    /// whose first element is used — see the handling in
    /// `generate_case_sensitivity`, which ignores any other shapes.
    #[serde(default)]
    identifier_quotes: Vec<serde_json::Value>,
    /// Not read by any generator in this script.
    #[serde(default)]
    string_escapes: Vec<String>,
}
/// Parser flags mirrored from the upstream spec. Currently deserialized only
/// for completeness — no generator in this script reads these fields.
#[derive(Debug, Deserialize, Default)]
#[allow(dead_code)] struct ParserSettings {
    #[serde(default)]
    tablesample_csv: bool,
    #[serde(default)]
    log_defaults_to_ln: bool,
}
/// SQL-generator flags mirrored from the upstream spec. Currently
/// deserialized only for completeness — no generator in this script reads
/// these fields (hence the `allow(dead_code)`).
#[derive(Debug, Deserialize, Default)]
#[allow(dead_code)] struct GeneratorSettings {
    #[serde(default)]
    limit_fetch: Option<String>,
    #[serde(default)]
    tablesample_size_is_rows: bool,
    #[serde(default)]
    locking_reads_supported: bool,
    #[serde(default)]
    null_ordering_supported: Option<bool>,
    #[serde(default)]
    ignore_nulls_in_func: bool,
    #[serde(default)]
    can_implement_array_any: bool,
    #[serde(default)]
    supports_table_alias_columns: bool,
    #[serde(default)]
    unpivot_aliases_are_identifiers: bool,
    #[serde(default)]
    custom_transforms_count: Option<usize>,
}
/// Reads and parses specs/dialect-semantics/dialects.json into a sorted map
/// keyed by dialect name. Both I/O and parse errors carry the file path.
fn load_dialects_json(spec_dir: &Path) -> Result<BTreeMap<String, DialectSpec>, Box<dyn Error>> {
    let path = spec_dir.join("dialects.json");
    let content = match fs::read_to_string(&path) {
        Ok(text) => text,
        Err(e) => return Err(format!("Failed to read {path:?}: {e}").into()),
    };
    match serde_json::from_str(&content) {
        Ok(parsed) => Ok(parsed),
        Err(e) => Err(format!("Failed to parse {path:?}: {e}").into()),
    }
}
/// One entry of functions.json.
///
/// Only `class` and `categories` are consumed (by `generate_functions`); the
/// rest is deserialized for spec fidelity.
#[derive(Debug, Deserialize)]
#[allow(dead_code)] struct FunctionDef {
    /// CamelCase class name, converted to snake_case via `class_name_to_sql_name`.
    class: String,
    /// Category tags; "aggregate", "window", and "udtf" drive set membership.
    categories: Vec<String>,
    /// Not read by any generator in this script.
    #[serde(default)]
    sql_names: Vec<String>,
    /// Not read by any generator in this script.
    #[serde(default)]
    arg_types: HashMap<String, serde_json::Value>,
    /// Not read by any generator in this script.
    #[serde(default)]
    dialects: Vec<String>,
    /// Not read by any generator in this script.
    #[serde(default)]
    dialect_specific: bool,
}
/// Reads and parses specs/dialect-semantics/functions.json into a sorted map
/// keyed by function identifier. Both I/O and parse errors carry the file path.
fn load_functions_json(spec_dir: &Path) -> Result<BTreeMap<String, FunctionDef>, Box<dyn Error>> {
    let path = spec_dir.join("functions.json");
    let content = match fs::read_to_string(&path) {
        Ok(text) => text,
        Err(e) => return Err(format!("Failed to read {path:?}: {e}").into()),
    };
    match serde_json::from_str(&content) {
        Ok(parsed) => Ok(parsed),
        Err(e) => Err(format!("Failed to parse {path:?}: {e}").into()),
    }
}
/// One entry of normalization_overrides.toml.
///
/// Only `has_custom_normalization` is consumed (by
/// `generate_case_sensitivity`); the rest is deserialized for spec fidelity.
#[derive(Debug, Deserialize)]
#[allow(dead_code)] struct NormalizationOverride {
    /// Not read by any generator in this script.
    normalization_strategy: String,
    /// When true, the dialect is included in `has_custom_normalization()`.
    has_custom_normalization: bool,
    /// Not read by any generator in this script.
    #[serde(default)]
    override_reason: Option<String>,
    /// Not read by any generator in this script.
    #[serde(default)]
    udf_case_sensitive: Option<bool>,
    /// Not read by any generator in this script.
    #[serde(default)]
    qualified_table_case_sensitive: Option<bool>,
}
/// Reads and parses specs/dialect-semantics/normalization_overrides.toml into
/// a sorted map keyed by dialect name. Errors carry the file path.
fn load_normalization_overrides(
    spec_dir: &Path,
) -> Result<BTreeMap<String, NormalizationOverride>, Box<dyn Error>> {
    let path = spec_dir.join("normalization_overrides.toml");
    let content = match fs::read_to_string(&path) {
        Ok(text) => text,
        Err(e) => return Err(format!("Failed to read {path:?}: {e}").into()),
    };
    match toml::from_str(&content) {
        Ok(parsed) => Ok(parsed),
        Err(e) => Err(format!("Failed to parse {path:?}: {e}").into()),
    }
}
/// One entry of scoping_rules.toml: where a dialect allows SELECT aliases to
/// be referenced. Each flag becomes a generated `const fn` on `Dialect`.
#[derive(Debug, Deserialize)]
struct ScopingRule {
    /// Aliases usable in GROUP BY.
    alias_in_group_by: bool,
    /// Aliases usable in HAVING.
    alias_in_having: bool,
    /// Aliases usable in ORDER BY.
    alias_in_order_by: bool,
    /// Later SELECT items may reference earlier ones.
    lateral_column_alias: bool,
}
/// Reads and parses specs/dialect-semantics/scoping_rules.toml into a sorted
/// map keyed by dialect name. Errors carry the file path.
fn load_scoping_rules(spec_dir: &Path) -> Result<BTreeMap<String, ScopingRule>, Box<dyn Error>> {
    let path = spec_dir.join("scoping_rules.toml");
    let content = match fs::read_to_string(&path) {
        Ok(text) => text,
        Err(e) => return Err(format!("Failed to read {path:?}: {e}").into()),
    };
    match toml::from_str(&content) {
        Ok(parsed) => Ok(parsed),
        Err(e) => Err(format!("Failed to parse {path:?}: {e}").into()),
    }
}
/// Root structure of dialect_behavior.toml.
#[derive(Debug, Deserialize)]
struct DialectBehavior {
    /// `[null_ordering]`: dialect -> "nulls_are_large" | "nulls_are_small" |
    /// "nulls_are_last" (matched in `generate_function_rules`; anything else
    /// falls back to NullsAreLast).
    null_ordering: BTreeMap<String, String>,
    /// `[unnest]` table; see `UnnestBehavior`.
    unnest: UnnestBehavior,
    /// `[date_functions]`: function name -> (dialect or "_default") -> TOML
    /// value holding the argument indices to skip (see `parse_skip_indices`).
    date_functions: BTreeMap<String, BTreeMap<String, toml::Value>>,
}
/// `[unnest]` section of dialect_behavior.toml.
#[derive(Debug, Deserialize)]
struct UnnestBehavior {
    /// Dialects for which `supports_implicit_unnest()` is generated as true.
    implicit_unnest: Vec<String>,
}
/// Reads and parses specs/dialect-semantics/dialect_behavior.toml.
/// Errors carry the file path.
fn load_dialect_behavior(spec_dir: &Path) -> Result<DialectBehavior, Box<dyn Error>> {
    let path = spec_dir.join("dialect_behavior.toml");
    let content = match fs::read_to_string(&path) {
        Ok(text) => text,
        Err(e) => return Err(format!("Failed to read {path:?}: {e}").into()),
    };
    match toml::from_str(&content) {
        Ok(parsed) => Ok(parsed),
        Err(e) => Err(format!("Failed to parse {path:?}: {e}").into()),
    }
}
/// Emits a `cargo:warning` for every KNOWN_DIALECTS entry that is absent from
/// dialects.json or scoping_rules.toml. Purely advisory — never fails the build.
fn validate_dialect_coverage(
    dialects: &BTreeMap<String, DialectSpec>,
    scoping: &BTreeMap<String, ScopingRule>,
) {
    for dialect in KNOWN_DIALECTS {
        if !dialects.contains_key(*dialect) {
            println!("cargo:warning=Dialect '{dialect}' missing from dialects.json");
        }
        if !scoping.contains_key(*dialect) {
            println!("cargo:warning=Dialect '{dialect}' missing from scoping_rules.toml");
        }
    }
}
/// Writes `content` to `path` only if it differs from what is on disk,
/// preserving file mtimes so cargo does not rebuild dependents needlessly.
///
/// The comparison is done on raw bytes (`fs::read`) rather than
/// `read_to_string`: a corrupt or non-UTF-8 generated file then simply
/// compares unequal and gets rewritten, instead of aborting the whole build
/// with an `InvalidData` error.
///
/// # Errors
///
/// Returns an error if the existing file cannot be read (other than
/// `NotFound`) or the new content cannot be written.
fn write_if_changed(path: &Path, content: &str) -> Result<(), Box<dyn Error>> {
    let write_needed = match fs::read(path) {
        Ok(existing) => existing != content.as_bytes(),
        // Missing file: first generation, always write.
        Err(err) if err.kind() == io::ErrorKind::NotFound => true,
        Err(err) => return Err(format!("Failed to read {path:?}: {err}").into()),
    };
    if write_needed {
        if let Err(err) = fs::write(path, content) {
            return Err(format!("Failed to write {path:?}: {err}").into());
        }
    }
    Ok(())
}
/// Writes the fully-static `src/generated/mod.rs` (module declarations and
/// re-exports for the generated files).
fn generate_mod_rs(dir: &Path) -> Result<(), Box<dyn Error>> {
    // Emitted verbatim; no spec input feeds into this file.
    let content = r#"//! Generated dialect semantic code.
//!
//! DO NOT EDIT MANUALLY - generated by build.rs from specs/dialect-semantics/
pub mod case_sensitivity;
pub mod function_rules;
pub mod functions;
mod scoping_rules;
pub use case_sensitivity::*;
pub use function_rules::*;
pub use functions::*;
// scoping_rules adds methods to Dialect via impl, no re-export needed
"#;
    let target = dir.join("mod.rs");
    write_if_changed(&target, content)
}
/// Generates `src/generated/case_sensitivity.rs`.
///
/// The emitted file contains:
/// * the `NormalizationStrategy` enum and its `apply` helper (static text),
/// * `Dialect::normalization_strategy()` from the `normalization` field of
///   dialects.json,
/// * `Dialect::has_custom_normalization()` from normalization_overrides.toml,
/// * `Dialect::pseudocolumns()` and `Dialect::identifier_quotes()` from
///   dialects.json.
fn generate_case_sensitivity(
    dir: &Path,
    dialects: &BTreeMap<String, DialectSpec>,
    overrides: &BTreeMap<String, NormalizationOverride>,
) -> Result<(), Box<dyn Error>> {
    // Static prelude of the generated file; the raw string is emitted
    // verbatim, so do not reformat its contents.
    let mut code = String::from(
        r#"//! Case sensitivity rules per dialect.
//!
//! Generated from dialects.json and normalization_overrides.toml
//!
//! This module defines how SQL identifiers (table names, column names, etc.)
//! should be normalized for comparison. Different SQL dialects have different
//! rules for identifier case sensitivity.
use std::borrow::Cow;
use crate::Dialect;
/// Normalization strategy for identifier handling.
///
/// SQL dialects differ in how they handle identifier case. This enum represents
/// the different strategies used for normalizing identifiers during analysis.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormalizationStrategy {
/// Fold to lowercase (Postgres, Redshift)
Lowercase,
/// Fold to uppercase (Snowflake, Oracle)
Uppercase,
/// Case-insensitive comparison without folding
CaseInsensitive,
/// Case-sensitive, preserve exactly
CaseSensitive,
}
impl NormalizationStrategy {
/// Applies this normalization strategy to a string.
///
/// Returns a `Cow<str>` to avoid allocation when no transformation is needed
/// (i.e., for `CaseSensitive` strategy or when the string is already in the
/// correct case).
///
/// For `CaseInsensitive`, lowercase folding is used as the canonical form.
///
/// # Example
///
/// ```
/// use std::borrow::Cow;
/// use flowscope_core::generated::NormalizationStrategy;
///
/// let strategy = NormalizationStrategy::Lowercase;
/// assert_eq!(strategy.apply("MyTable"), "mytable");
///
/// // CaseSensitive returns a borrowed reference (no allocation)
/// let strategy = NormalizationStrategy::CaseSensitive;
/// assert!(matches!(strategy.apply("MyTable"), Cow::Borrowed(_)));
/// ```
pub fn apply<'a>(&self, s: &'a str) -> Cow<'a, str> {
match self {
Self::CaseSensitive => Cow::Borrowed(s),
Self::Lowercase | Self::CaseInsensitive => {
// Optimization: only allocate if the string contains uppercase chars
if s.chars().any(|c| c.is_uppercase()) {
Cow::Owned(s.to_lowercase())
} else {
Cow::Borrowed(s)
}
}
Self::Uppercase => {
// Optimization: only allocate if the string contains lowercase chars
if s.chars().any(|c| c.is_lowercase()) {
Cow::Owned(s.to_uppercase())
} else {
Cow::Borrowed(s)
}
}
}
}
}
impl Dialect {
/// Get the normalization strategy for this dialect.
pub const fn normalization_strategy(&self) -> NormalizationStrategy {
match self {
"#,
    );
    // One `Dialect::X => Strategy` arm per dialect listed in dialects.json.
    for (dialect, spec) in dialects {
        if let Some(variant) = dialect_to_variant(dialect) {
            let strategy = match spec.normalization.as_str() {
                "lowercase" => "NormalizationStrategy::Lowercase",
                "uppercase" => "NormalizationStrategy::Uppercase",
                "case_insensitive" => "NormalizationStrategy::CaseInsensitive",
                "case_sensitive" => "NormalizationStrategy::CaseSensitive",
                // Unknown spec value: warn at build time and fall back.
                other => {
                    println!(
                        "cargo:warning=Unknown normalization '{other}' for dialect '{dialect}', using CaseInsensitive"
                    );
                    "NormalizationStrategy::CaseInsensitive"
                }
            };
            code.push_str(&format!(" Dialect::{variant} => {strategy},\n"));
        }
    }
    // Variants that never appear in the spec files get fixed strategies.
    code.push_str(" Dialect::Generic => NormalizationStrategy::CaseInsensitive,\n");
    code.push_str(" Dialect::Ansi => NormalizationStrategy::Uppercase,\n");
    code.push_str(" }\n }\n\n");
    // has_custom_normalization(): true only for dialects flagged in
    // normalization_overrides.toml.
    let custom_dialects: Vec<_> = overrides
        .iter()
        .filter(|(_, o)| o.has_custom_normalization)
        .map(|(d, _)| d)
        .collect();
    if custom_dialects.is_empty() {
        // No flagged dialects: emit a constant `false` body instead of an
        // empty (invalid) matches! list.
        code.push_str(
            r#" /// Returns true if this dialect has custom normalization logic
/// that cannot be captured by a simple strategy.
pub const fn has_custom_normalization(&self) -> bool {
false
}
"#,
        );
    } else {
        let variants: Vec<_> = custom_dialects
            .iter()
            .filter_map(|d| dialect_to_variant(d))
            .map(|v| format!("Dialect::{v}"))
            .collect();
        code.push_str(&format!(
            r#" /// Returns true if this dialect has custom normalization logic
/// that cannot be captured by a simple strategy.
pub const fn has_custom_normalization(&self) -> bool {{
matches!(self, {})
}}
"#,
            variants.join(" | ")
        ));
    }
    // pseudocolumns(): only dialects with a non-empty list get an arm;
    // everything else falls through to the empty-slice default.
    code.push_str(
        r#"
/// Get pseudocolumns for this dialect (implicit columns like _PARTITIONTIME).
pub fn pseudocolumns(&self) -> &'static [&'static str] {
match self {
"#,
    );
    for (dialect, spec) in dialects {
        if !spec.pseudocolumns.is_empty() {
            if let Some(variant) = dialect_to_variant(dialect) {
                let cols: Vec<_> = spec
                    .pseudocolumns
                    .iter()
                    .map(|s| format!("\"{s}\""))
                    .collect();
                let cols_str = cols.join(", ");
                code.push_str(&format!(
                    " Dialect::{variant} => &[{cols_str}],\n"
                ));
            }
        }
    }
    code.push_str(" _ => &[],\n");
    code.push_str(" }\n }\n");
    // identifier_quotes(): spec entries may be plain strings or arrays (for
    // paired quotes, only the opening char is kept); other shapes are skipped.
    code.push_str(
        r#"
/// Get the identifier quote characters for this dialect.
/// Note: Some dialects use paired quotes (like SQLite's []) which are represented
/// as single characters here - the opening bracket.
pub fn identifier_quotes(&self) -> &'static [&'static str] {
match self {
"#,
    );
    for (dialect, spec) in dialects {
        if let Some(ref qc) = spec.quote_chars {
            if !qc.identifier_quotes.is_empty() {
                if let Some(variant) = dialect_to_variant(dialect) {
                    let quotes: Vec<_> = qc
                        .identifier_quotes
                        .iter()
                        .filter_map(|v| {
                            match v {
                                serde_json::Value::String(s) => {
                                    // escape_default keeps the emitted Rust
                                    // string literal valid (e.g. for `"`)
                                    let escaped = s.escape_default();
                                    Some(format!("\"{escaped}\""))
                                }
                                serde_json::Value::Array(arr) => {
                                    arr.first().and_then(|v| v.as_str()).map(|s| {
                                        let escaped = s.escape_default();
                                        format!("\"{escaped}\"")
                                    })
                                }
                                _ => None,
                            }
                        })
                        .collect();
                    if !quotes.is_empty() {
                        let quotes_str = quotes.join(", ");
                        code.push_str(&format!(
                            " Dialect::{variant} => &[{quotes_str}],\n"
                        ));
                    }
                }
            }
        }
    }
    // Default: double-quote, the ANSI identifier quote.
    code.push_str(" _ => &[\"\\\"\"],\n");
    code.push_str(" }\n }\n}\n");
    write_if_changed(&dir.join("case_sensitivity.rs"), &code)
}
/// Generates `src/generated/scoping_rules.rs`: four `const fn` predicates on
/// `Dialect` describing where SELECT aliases may be referenced.
///
/// All four methods share the same shape — one match arm per dialect found in
/// scoping_rules.toml plus a `_` fallback — so the per-dialect arms are
/// emitted by the shared `push_scoping_arms` helper instead of four
/// copy-pasted loops.
fn generate_scoping_rules(
    dir: &Path,
    rules: &BTreeMap<String, ScopingRule>,
) -> Result<(), Box<dyn Error>> {
    let mut code = String::from(
        r#"//! Alias visibility and scoping rules per dialect.
//!
//! Generated from scoping_rules.toml
use crate::Dialect;
impl Dialect {
/// Whether SELECT aliases can be referenced in GROUP BY.
pub const fn alias_in_group_by(&self) -> bool {
match self {
"#,
    );
    push_scoping_arms(&mut code, rules, |r| r.alias_in_group_by);
    code.push_str(" _ => false, // Default: strict (Postgres-like)\n");
    code.push_str(" }\n }\n\n");
    code.push_str(
        r#" /// Whether SELECT aliases can be referenced in HAVING.
pub const fn alias_in_having(&self) -> bool {
match self {
"#,
    );
    push_scoping_arms(&mut code, rules, |r| r.alias_in_having);
    code.push_str(" _ => false,\n");
    code.push_str(" }\n }\n\n");
    code.push_str(
        r#" /// Whether SELECT aliases can be referenced in ORDER BY.
pub const fn alias_in_order_by(&self) -> bool {
match self {
"#,
    );
    push_scoping_arms(&mut code, rules, |r| r.alias_in_order_by);
    code.push_str(" _ => true, // ORDER BY alias is widely supported\n");
    code.push_str(" }\n }\n\n");
    code.push_str(
        r#" /// Whether lateral column aliases are supported (referencing earlier SELECT items).
pub const fn lateral_column_alias(&self) -> bool {
match self {
"#,
    );
    push_scoping_arms(&mut code, rules, |r| r.lateral_column_alias);
    code.push_str(" _ => false,\n");
    code.push_str(" }\n }\n}\n");
    write_if_changed(&dir.join("scoping_rules.rs"), &code)
}
/// Appends one `Dialect::X => <bool>,` match arm per dialect in `rules`,
/// reading the flag via `get`. Dialects without a `Dialect` variant are
/// skipped (dialect_to_variant warns about unknown names).
fn push_scoping_arms(
    code: &mut String,
    rules: &BTreeMap<String, ScopingRule>,
    get: fn(&ScopingRule) -> bool,
) {
    for (dialect, rule) in rules {
        if let Some(variant) = dialect_to_variant(dialect) {
            let val = get(rule);
            code.push_str(&format!(" Dialect::{variant} => {val},\n"));
        }
    }
}
fn generate_function_rules(dir: &Path, behavior: &DialectBehavior) -> Result<(), Box<dyn Error>> {
let mut code = String::from(
r#"//! Function argument handling rules per dialect.
//!
//! Generated from dialect_behavior.toml
//!
//! This module provides dialect-aware rules for function argument handling,
//! particularly for date/time functions where certain arguments are keywords
//! (like `YEAR`, `MONTH`) rather than column references.
use crate::Dialect;
/// Returns argument indices to skip when extracting column references from a function call.
///
/// Certain SQL functions take keyword arguments (e.g., `DATEDIFF(YEAR, start, end)` in Snowflake)
/// that should not be treated as column references during lineage analysis. This function
/// returns the indices of such arguments for the given function and dialect.
///
/// # Arguments
///
/// * `dialect` - The SQL dialect being analyzed
/// * `func_name` - The function name (case-insensitive, underscore-insensitive)
///
/// # Returns
///
/// A slice of argument indices (0-based) to skip. Returns an empty slice for
/// unknown functions or functions without skip rules.
///
/// # Example
///
/// ```ignore
/// // In Snowflake, DATEDIFF takes a unit as the first argument
/// let skip = skip_args_for_function(Dialect::Snowflake, "DATEDIFF");
/// assert_eq!(skip, &[0]); // Skip first argument (the unit)
///
/// // Both DATEADD and DATE_ADD match the same rules
/// let skip1 = skip_args_for_function(Dialect::Snowflake, "DATEADD");
/// let skip2 = skip_args_for_function(Dialect::Snowflake, "DATE_ADD");
/// assert_eq!(skip1, skip2);
/// ```
pub fn skip_args_for_function(dialect: Dialect, func_name: &str) -> &'static [usize] {
// Normalize: lowercase and remove underscores to handle both DATEADD and DATE_ADD variants
let func_normalized: String = func_name
.chars()
.filter(|c| *c != '_')
.map(|c| c.to_ascii_lowercase())
.collect();
match func_normalized.as_str() {
"#,
);
for (func_name, dialect_rules) in &behavior.date_functions {
let func_normalized: String = func_name
.chars()
.filter(|c| *c != '_')
.flat_map(|c| c.to_lowercase())
.collect();
let has_default = dialect_rules.contains_key("_default");
let dialect_specific_rules: Vec<_> = dialect_rules
.iter()
.filter(|(d, _)| *d != "_default" && dialect_to_variant(d).is_some())
.collect();
if dialect_specific_rules.is_empty() {
if has_default {
let default_indices = parse_skip_indices(dialect_rules.get("_default").unwrap());
if default_indices.is_empty() {
code.push_str(&format!(" \"{func_normalized}\" => &[],\n"));
} else {
let idx_str = default_indices
.iter()
.map(|i| i.to_string())
.collect::<Vec<_>>()
.join(", ");
code.push_str(&format!(" \"{func_normalized}\" => &[{idx_str}],\n"));
}
} else {
code.push_str(&format!(" \"{func_normalized}\" => &[],\n"));
}
continue;
}
code.push_str(&format!(
" \"{func_normalized}\" => match dialect {{\n"
));
for (dialect, value) in dialect_rules {
if dialect == "_default" {
continue;
}
if let Some(variant) = dialect_to_variant(dialect) {
let indices = parse_skip_indices(value);
if indices.is_empty() {
code.push_str(&format!(" Dialect::{variant} => &[],\n"));
} else {
let idx_str = indices
.iter()
.map(|i| i.to_string())
.collect::<Vec<_>>()
.join(", ");
code.push_str(&format!(
" Dialect::{variant} => &[{idx_str}],\n"
));
}
}
}
if has_default {
let default_indices = parse_skip_indices(dialect_rules.get("_default").unwrap());
if default_indices.is_empty() {
code.push_str(" _ => &[],\n");
} else {
let idx_str = default_indices
.iter()
.map(|i| i.to_string())
.collect::<Vec<_>>()
.join(", ");
code.push_str(&format!(" _ => &[{idx_str}],\n"));
}
} else {
code.push_str(" _ => &[],\n");
}
code.push_str(" },\n");
}
code.push_str(" _ => &[],\n");
code.push_str(" }\n}\n\n");
code.push_str(
r#"
/// NULL ordering behavior in ORDER BY.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NullOrdering {
/// NULLs sort as larger than all other values (NULLS LAST for ASC)
NullsAreLarge,
/// NULLs sort as smaller than all other values (NULLS FIRST for ASC)
NullsAreSmall,
/// NULLs always sort last regardless of ASC/DESC
NullsAreLast,
}
impl Dialect {
/// Get the default NULL ordering behavior for this dialect.
pub const fn null_ordering(&self) -> NullOrdering {
match self {
"#,
);
for (dialect, ordering) in &behavior.null_ordering {
if let Some(variant) = dialect_to_variant(dialect) {
let ordering_variant = match ordering.as_str() {
"nulls_are_large" => "NullOrdering::NullsAreLarge",
"nulls_are_small" => "NullOrdering::NullsAreSmall",
"nulls_are_last" => "NullOrdering::NullsAreLast",
_ => "NullOrdering::NullsAreLast",
};
code.push_str(&format!(
" Dialect::{variant} => {ordering_variant},\n"
));
}
}
code.push_str(" _ => NullOrdering::NullsAreLast,\n");
code.push_str(" }\n }\n\n");
let implicit_variants: Vec<_> = behavior
.unnest
.implicit_unnest
.iter()
.filter_map(|d| dialect_to_variant(d))
.map(|v| format!("Dialect::{v}"))
.collect();
if implicit_variants.is_empty() {
code.push_str(
r#" /// Whether this dialect supports implicit UNNEST (no CROSS JOIN needed).
pub const fn supports_implicit_unnest(&self) -> bool {
false
}
}
"#,
);
} else {
code.push_str(&format!(
r#" /// Whether this dialect supports implicit UNNEST (no CROSS JOIN needed).
pub const fn supports_implicit_unnest(&self) -> bool {{
matches!(self, {})
}}
}}
"#,
implicit_variants.join(" | ")
));
}
write_if_changed(&dir.join("function_rules.rs"), &code)
}
/// Converts a CamelCase class name to a snake_case SQL function name,
/// handling acronym runs: "DateAdd" -> "date_add", "JSONExtract" ->
/// "json_extract", "Sum" -> "sum".
///
/// An underscore is inserted before an uppercase char when the previous char
/// was a non-uppercase letter, or when the uppercase char ends an acronym run
/// (i.e. the next char is lowercase, like the 'E' in "JSONExtract").
fn class_name_to_sql_name(class_name: &str) -> String {
    // Collect once so lookahead is O(1); the original `chars().nth(i + 1)`
    // restarted the iterator every uppercase char, making this O(n^2).
    let chars: Vec<char> = class_name.chars().collect();
    let mut result = String::with_capacity(class_name.len() + 4);
    let mut prev_was_upper = false;
    let mut prev_was_letter = false;
    for (i, &c) in chars.iter().enumerate() {
        if c.is_uppercase() {
            if i > 0 && prev_was_letter {
                let next_is_lower = chars.get(i + 1).is_some_and(|nc| nc.is_lowercase());
                if !prev_was_upper || next_is_lower {
                    result.push('_');
                }
            }
            result.push(c.to_ascii_lowercase());
            prev_was_upper = true;
        } else {
            result.push(c.to_ascii_lowercase());
            prev_was_upper = false;
        }
        prev_was_letter = c.is_alphabetic();
    }
    result
}
/// Generates `src/generated/functions.rs`: three lazily-built `HashSet`
/// statics of snake_case function names (aggregate / window / UDTF) plus
/// case-insensitive lookup helpers.
///
/// A function listed under several categories lands in several sets. The
/// three statics share an identical shape, so they are emitted by the
/// `push_function_set` helper instead of three copy-pasted chunks.
fn generate_functions(
    dir: &Path,
    functions: &BTreeMap<String, FunctionDef>,
) -> Result<(), Box<dyn Error>> {
    // Classify every function's snake_case name into the category sets.
    // BTreeSet keeps the generated inserts sorted and deduplicated.
    let mut aggregates: BTreeSet<String> = BTreeSet::new();
    let mut windows: BTreeSet<String> = BTreeSet::new();
    let mut udtfs: BTreeSet<String> = BTreeSet::new();
    for def in functions.values() {
        let sql_name = class_name_to_sql_name(&def.class);
        for cat in &def.categories {
            match cat.as_str() {
                "aggregate" => {
                    aggregates.insert(sql_name.clone());
                }
                "window" => {
                    windows.insert(sql_name.clone());
                }
                "udtf" => {
                    udtfs.insert(sql_name.clone());
                }
                // Other categories are irrelevant to these sets.
                _ => {}
            }
        }
    }
    // Static prelude of the generated file; emitted verbatim.
    let mut code = String::from(
        r#"//! Function classification sets.
//!
//! Generated from functions.json
//!
//! This module provides sets of SQL function names categorized by their behavior
//! (aggregate, window, table-generating). These classifications are used during
//! lineage analysis to determine how expressions should be analyzed.
use std::collections::HashSet;
use std::sync::LazyLock;
"#,
    );
    push_function_set(&mut code, "Aggregate functions", "AGGREGATE_FUNCTIONS", &aggregates);
    push_function_set(&mut code, "Window functions", "WINDOW_FUNCTIONS", &windows);
    push_function_set(
        &mut code,
        "Table-generating functions / UDTFs",
        "UDTF_FUNCTIONS",
        &udtfs,
    );
    // Static lookup helpers over the generated sets.
    code.push_str(
        r#"/// Checks if a function is an aggregate function (e.g., SUM, COUNT, AVG).
///
/// Aggregate functions combine multiple input rows into a single output value.
/// This classification is used to detect aggregation in SELECT expressions
/// and validate GROUP BY semantics.
///
/// The check is case-insensitive. Uses ASCII lowercase for performance since
/// SQL function names are always ASCII.
pub fn is_aggregate_function(name: &str) -> bool {
// SQL function names are ASCII, so we can use the faster ASCII lowercase
let lower = name.to_ascii_lowercase();
AGGREGATE_FUNCTIONS.contains(lower.as_str())
}
/// Checks if a function is a window function (e.g., ROW_NUMBER, RANK, LAG).
///
/// Window functions perform calculations across a set of rows related to
/// the current row, without collapsing them into a single output.
///
/// The check is case-insensitive. Uses ASCII lowercase for performance since
/// SQL function names are always ASCII.
pub fn is_window_function(name: &str) -> bool {
// SQL function names are ASCII, so we can use the faster ASCII lowercase
let lower = name.to_ascii_lowercase();
WINDOW_FUNCTIONS.contains(lower.as_str())
}
/// Checks if a function is a table-generating function / UDTF (e.g., UNNEST, EXPLODE).
///
/// UDTFs return multiple rows for each input row, expanding the result set.
/// This classification affects how lineage is tracked through these functions.
///
/// The check is case-insensitive. Uses ASCII lowercase for performance since
/// SQL function names are always ASCII.
pub fn is_udtf_function(name: &str) -> bool {
// SQL function names are ASCII, so we can use the faster ASCII lowercase
let lower = name.to_ascii_lowercase();
UDTF_FUNCTIONS.contains(lower.as_str())
}
"#,
    );
    write_if_changed(&dir.join("functions.rs"), &code)
}
/// Appends one `/// <label> (<n> total).` doc line and a
/// `LazyLock<HashSet<&'static str>>` static named `static_name`, populated
/// with every name in `funcs`.
fn push_function_set(code: &mut String, label: &str, static_name: &str, funcs: &BTreeSet<String>) {
    let count = funcs.len();
    code.push_str(&format!("/// {label} ({count} total).\n"));
    code.push_str(&format!(
        "pub static {static_name}: LazyLock<HashSet<&'static str>> = LazyLock::new(|| {{\n"
    ));
    code.push_str(" let mut set = HashSet::new();\n");
    for func in funcs {
        code.push_str(&format!(" set.insert(\"{func}\");\n"));
    }
    code.push_str(" set\n});\n\n");
}
/// Maps a spec dialect name (any case) to the corresponding `Dialect` enum
/// variant name, or `None` for dialects the crate does not model.
///
/// Names that are neither modeled nor on the known-unsupported list trigger a
/// build warning so spec typos are noticed.
fn dialect_to_variant(dialect: &str) -> Option<&'static str> {
    let normalized = dialect.to_lowercase();
    let variant = match normalized.as_str() {
        "bigquery" => "Bigquery",
        "clickhouse" => "Clickhouse",
        "databricks" => "Databricks",
        "duckdb" => "Duckdb",
        "hive" => "Hive",
        // "tsql" is an accepted alias for MSSQL in the specs.
        "mssql" | "tsql" => "Mssql",
        "mysql" => "Mysql",
        "postgres" => "Postgres",
        "redshift" => "Redshift",
        "snowflake" => "Snowflake",
        "sqlite" => "Sqlite",
        // Recognized upstream dialects we deliberately do not model: skip silently.
        "doris" | "drill" | "oracle" | "presto" | "spark" | "starrocks" | "tableau"
        | "teradata" | "trino" => return None,
        _ => {
            println!("cargo:warning=Unknown dialect '{dialect}' in specs");
            return None;
        }
    };
    Some(variant)
}
/// Extracts the integer elements of a TOML array as `usize` skip indices.
/// Non-array values and non-integer elements yield nothing.
fn parse_skip_indices(value: &toml::Value) -> Vec<usize> {
    let toml::Value::Array(items) = value else {
        return Vec::new();
    };
    items
        .iter()
        .filter_map(|item| item.as_integer())
        .map(|i| i as usize)
        .collect()
}