use std::ffi::{CStr, CString};
use libduckdb_sys::{
duckdb_connection, duckdb_destroy_result, duckdb_query, duckdb_result, duckdb_result_error,
DuckDBSuccess,
};
use crate::error::ExtensionError;
use crate::validate::validate_function_name;
/// The body of a DuckDB SQL macro.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MacroBody {
    /// A single SQL expression; rendered as `AS (<expression>)`.
    Scalar(String),
    /// A SQL query; rendered as `AS TABLE <query>`.
    Table(String),
}
/// A DuckDB SQL macro definition: a name, its parameter names, and a body.
///
/// The public constructors are [`SqlMacro::scalar`] and [`SqlMacro::table`], which
/// validate the name and parameter names before building the value.
#[derive(Clone, Debug)]
pub struct SqlMacro {
    /// The macro's name as it appears in the generated `CREATE OR REPLACE MACRO`.
    name: String,
    /// Parameter names, in declaration order.
    params: Vec<String>,
    /// Either a scalar expression or a table query.
    body: MacroBody,
}
impl SqlMacro {
    /// Creates a scalar macro whose body is a single SQL expression.
    ///
    /// The macro renders as `CREATE OR REPLACE MACRO name(params) AS (expression)`.
    /// `expression` is neither validated nor escaped: it is embedded verbatim in the
    /// generated SQL, so it must be trusted SQL text.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtensionError`] if `name` or any entry of `params` fails name
    /// validation.
    pub fn scalar(
        name: &str,
        params: &[&str],
        expression: impl Into<String>,
    ) -> Result<Self, ExtensionError> {
        let (name, params) = validate_name_and_params(name, params)?;
        Ok(Self {
            name,
            params,
            body: MacroBody::Scalar(expression.into()),
        })
    }

    /// Creates a table macro whose body is a SQL query.
    ///
    /// The macro renders as `CREATE OR REPLACE MACRO name(params) AS TABLE query`.
    /// `query` is neither validated nor escaped: it is embedded verbatim in the
    /// generated SQL, so it must be trusted SQL text.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtensionError`] if `name` or any entry of `params` fails name
    /// validation.
    pub fn table(
        name: &str,
        params: &[&str],
        query: impl Into<String>,
    ) -> Result<Self, ExtensionError> {
        let (name, params) = validate_name_and_params(name, params)?;
        Ok(Self {
            name,
            params,
            body: MacroBody::Table(query.into()),
        })
    }

    /// Renders the `CREATE OR REPLACE MACRO` statement for this macro.
    ///
    /// Parameters are joined with `", "`. A scalar body is wrapped in parentheses
    /// (`AS (expr)`); a table body is appended after `AS TABLE` unchanged.
    #[must_use]
    pub fn to_sql(&self) -> String {
        let params = self.params.join(", ");
        match &self.body {
            MacroBody::Scalar(expr) => {
                format!(
                    "CREATE OR REPLACE MACRO {}({}) AS ({})",
                    self.name, params, expr
                )
            }
            MacroBody::Table(query) => {
                format!(
                    "CREATE OR REPLACE MACRO {}({}) AS TABLE {}",
                    self.name, params, query
                )
            }
        }
    }

    /// Registers the macro on a DuckDB connection by executing the statement from
    /// [`Self::to_sql`].
    ///
    /// The statement uses `CREATE OR REPLACE`, so an existing macro with the same name
    /// is replaced.
    ///
    /// # Safety
    ///
    /// `con` must be a valid, open `duckdb_connection` handle that remains valid (is not
    /// disconnected or destroyed) for the duration of this call.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtensionError`] if the generated SQL contains an interior NUL byte
    /// (for example from the body text), or if DuckDB fails to execute the statement. In
    /// the second case the error carries DuckDB's message when one is available.
    pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
        let sql = self.to_sql();
        // SAFETY: the caller guarantees `con` is a valid, open connection for the whole
        // call, which is exactly the precondition `execute_sql` requires.
        unsafe { execute_sql(con, &sql) }
    }

    /// Returns the macro's name.
    #[must_use]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the macro's parameter names, in declaration order.
    #[must_use]
    pub fn params(&self) -> &[String] {
        &self.params
    }

    /// Returns the macro's body (scalar expression or table query).
    #[must_use]
    pub const fn body(&self) -> &MacroBody {
        &self.body
    }
}
/// Validates a macro name and its parameter names, returning owned copies of both.
///
/// The macro name and every parameter name must pass `validate_function_name`. Parameter
/// names must also be unique: duplicates would make references inside the macro body
/// ambiguous. They are rejected here rather than deferred to DuckDB at registration time.
///
/// # Errors
///
/// Returns an [`ExtensionError`] in these cases:
/// - `name` is invalid; the validator's error is returned unchanged.
/// - A parameter name is invalid; the message names the parameter.
/// - A parameter name appears more than once.
fn validate_name_and_params(
    name: &str,
    params: &[&str],
) -> Result<(String, Vec<String>), ExtensionError> {
    validate_function_name(name)?;
    for (idx, &param) in params.iter().enumerate() {
        validate_function_name(param)
            .map_err(|e| ExtensionError::new(format!("invalid parameter name '{param}': {e}")))?;
        // Macros have only a handful of parameters, so a linear scan of the earlier
        // entries is cheaper and simpler than building a set.
        if params[..idx].contains(&param) {
            return Err(ExtensionError::new(format!(
                "duplicate parameter name '{param}'"
            )));
        }
    }
    Ok((
        name.to_owned(),
        params.iter().map(|&p| p.to_owned()).collect(),
    ))
}
/// Executes one SQL statement on `con` and discards its result.
///
/// # Safety
///
/// `con` must be a valid, open `duckdb_connection` handle for the duration of the call.
///
/// # Errors
///
/// Returns an [`ExtensionError`] in two cases:
/// - `sql` contains an interior NUL byte.
/// - `duckdb_query` does not report success. The error then carries DuckDB's message, or
///   a fallback text when DuckDB provides none.
unsafe fn execute_sql(con: duckdb_connection, sql: &str) -> Result<(), ExtensionError> {
    let Ok(statement) = CString::new(sql) else {
        return Err(ExtensionError::new("SQL statement contains interior null bytes"));
    };

    // SAFETY: `duckdb_result` is a plain C struct used purely as an out-parameter;
    // `duckdb_query` fills it in, so an all-zero starting value is acceptable.
    let mut res: duckdb_result = unsafe { std::mem::zeroed() };

    // SAFETY: the caller guarantees `con` is valid; `statement` is NUL-terminated and
    // outlives the call; `res` is a live, writable `duckdb_result`.
    let state = unsafe { duckdb_query(con, statement.as_ptr(), &raw mut res) };

    let outcome = if state == DuckDBSuccess {
        Ok(())
    } else {
        // SAFETY: `res` was populated by `duckdb_query` above and is not destroyed yet.
        let err_ptr = unsafe { duckdb_result_error(&raw mut res) };
        // Copy the message into an owned String before `res` (which owns it) is freed.
        let message = (!err_ptr.is_null())
            .then(|| {
                // SAFETY: `err_ptr` is non-null and points to a NUL-terminated string
                // owned by `res`, which is still alive here.
                unsafe { CStr::from_ptr(err_ptr) }
                    .to_string_lossy()
                    .into_owned()
            })
            .unwrap_or_else(|| {
                "DuckDB macro registration failed (no error message available)".to_string()
            });
        Err(ExtensionError::new(message))
    };

    // SAFETY: `res` is destroyed exactly once, on both the success and failure paths,
    // and nothing borrowed from it outlives this call.
    unsafe { duckdb_destroy_result(&raw mut res) };
    outcome
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a macro that must be valid and returns its rendered SQL.
    fn rendered(built: Result<SqlMacro, ExtensionError>) -> String {
        built.expect("macro definition should be valid").to_sql()
    }

    #[test]
    fn scalar_no_params_to_sql() {
        let sql = rendered(SqlMacro::scalar("pi", &[], "3.14159265358979"));
        assert_eq!(sql, "CREATE OR REPLACE MACRO pi() AS (3.14159265358979)");
    }

    #[test]
    fn scalar_one_param_to_sql() {
        let sql = rendered(SqlMacro::scalar("double_it", &["x"], "x * 2"));
        assert_eq!(sql, "CREATE OR REPLACE MACRO double_it(x) AS (x * 2)");
    }

    #[test]
    fn scalar_multiple_params_to_sql() {
        let sql = rendered(SqlMacro::scalar("add", &["a", "b"], "a + b"));
        assert_eq!(sql, "CREATE OR REPLACE MACRO add(a, b) AS (a + b)");
    }

    #[test]
    fn scalar_complex_expression_to_sql() {
        let sql = rendered(SqlMacro::scalar(
            "clamp",
            &["x", "lo", "hi"],
            "greatest(lo, least(hi, x))",
        ));
        assert_eq!(
            sql,
            "CREATE OR REPLACE MACRO clamp(x, lo, hi) AS (greatest(lo, least(hi, x)))"
        );
    }

    #[test]
    fn table_no_params_to_sql() {
        let sql = rendered(SqlMacro::table("all_data", &[], "SELECT 1 AS n"));
        assert_eq!(sql, "CREATE OR REPLACE MACRO all_data() AS TABLE SELECT 1 AS n");
    }

    #[test]
    fn table_with_param_to_sql() {
        let sql = rendered(SqlMacro::table(
            "active_rows",
            &["tbl"],
            "SELECT * FROM tbl WHERE active = true",
        ));
        assert_eq!(
            sql,
            "CREATE OR REPLACE MACRO active_rows(tbl) AS TABLE SELECT * FROM tbl WHERE active = true"
        );
    }

    #[test]
    fn invalid_macro_name_uppercase_rejected() {
        let outcome = SqlMacro::scalar("MyMacro", &[], "1");
        assert!(outcome.is_err());
    }

    #[test]
    fn invalid_macro_name_hyphen_rejected() {
        let outcome = SqlMacro::scalar("my-macro", &[], "1");
        assert!(outcome.is_err());
    }

    #[test]
    fn invalid_macro_name_empty_rejected() {
        let outcome = SqlMacro::scalar("", &[], "1");
        assert!(outcome.is_err());
    }

    #[test]
    fn invalid_param_uppercase_rejected() {
        let Err(failure) = SqlMacro::scalar("f", &["BadParam"], "1") else {
            panic!("an uppercase parameter name must be rejected");
        };
        assert!(failure.as_str().contains("BadParam"));
    }

    #[test]
    fn invalid_param_hyphen_rejected() {
        let outcome = SqlMacro::scalar("f", &["a-b"], "1");
        assert!(outcome.is_err());
    }

    #[test]
    fn valid_underscore_prefix_param() {
        let outcome = SqlMacro::scalar("f", &["_x"], "1");
        assert!(outcome.is_ok());
    }

    #[test]
    fn valid_single_letter_params() {
        let built = SqlMacro::scalar("clamp", &["x", "lo", "hi"], "1").unwrap();
        assert_eq!(built.params(), ["x", "lo", "hi"]);
    }

    #[test]
    fn name_and_params_stored_correctly() {
        let built = SqlMacro::scalar("f", &["a", "b", "c"], "a+b+c").unwrap();
        assert_eq!(built.name(), "f");
        assert_eq!(built.params(), ["a", "b", "c"]);
    }

    #[test]
    fn scalar_body_variant() {
        let built = SqlMacro::scalar("f", &["x"], "x + 1").unwrap();
        assert!(matches!(built.body(), MacroBody::Scalar(expr) if expr == "x + 1"));
    }

    #[test]
    fn table_body_variant() {
        let built = SqlMacro::table("t", &[], "SELECT 1").unwrap();
        assert!(matches!(built.body(), MacroBody::Table(query) if query == "SELECT 1"));
    }

    #[test]
    fn sql_macro_is_cloneable() {
        let original = SqlMacro::scalar("f", &["x"], "x").unwrap();
        let copy = original.clone();
        assert_eq!(copy.to_sql(), original.to_sql());
    }

    #[test]
    fn macro_body_is_eq() {
        let scalar = MacroBody::Scalar("x".into());
        assert_eq!(scalar, MacroBody::Scalar("x".into()));
        assert_ne!(scalar, MacroBody::Table("x".into()));
    }
}