Skip to main content

quack_rs/
sql_macro.rs

1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tom F. <https://github.com/tomtom215/>
3// My way of giving something small back to the open source community
4// and encouraging more Rust development!
5
6//! SQL macro registration for `DuckDB` extensions.
7//!
8//! SQL macros let you package reusable SQL expressions and queries as
9//! named `DuckDB` functions — no FFI callbacks required. This module
10//! provides a safe Rust builder for creating both scalar and table macros
11//! via `CREATE OR REPLACE MACRO` statements executed during extension
12//! initialization.
13//!
14//! # Macro types
15//!
16//! | Type | SQL | Returns |
17//! |------|-----|---------|
18//! | **Scalar** | `AS (expression)` | one value per row |
19//! | **Table** | `AS TABLE query`  | a table |
20//!
21//! # Example
22//!
23//! ```rust,no_run
24//! use quack_rs::sql_macro::SqlMacro;
25//! use quack_rs::error::ExtensionError;
26//!
27//! fn register(con: libduckdb_sys::duckdb_connection) -> Result<(), ExtensionError> {
28//!     unsafe {
29//!         // Scalar macro: clamp(x, lo, hi) — no C++ needed!
30//!         SqlMacro::scalar("clamp", &["x", "lo", "hi"], "greatest(lo, least(hi, x))")?
31//!             .register(con)?;
32//!
33//!         // Table macro: active_rows(tbl) — returns filtered rows
34//!         SqlMacro::table("active_rows", &["tbl"], "SELECT * FROM tbl WHERE active = true")?
35//!             .register(con)?;
36//!     }
37//!     Ok(())
38//! }
39//! ```
40//!
41//! # SQL injection safety
42//!
//! Macro names and parameter names are validated against
//! [`validate_function_name`]:
//! only identifiers matching `[a-z_][a-z0-9_]*` are accepted. These names are
//! interpolated literally into the generated SQL (no quoting is required
//! because they are already restricted to safe characters).
48//!
49//! The SQL body (`expression` / `query`) is your own extension code, not
50//! user-supplied input. **Never build macro bodies from untrusted runtime
51//! data.** There is no escaping applied to the body.
52
53use std::ffi::{CStr, CString};
54
55use libduckdb_sys::{
56    duckdb_connection, duckdb_destroy_result, duckdb_query, duckdb_result, duckdb_result_error,
57    DuckDBSuccess,
58};
59
60use crate::error::ExtensionError;
61use crate::validate::validate_function_name;
62
/// What a SQL macro expands to: either a scalar expression or a table query.
///
/// You normally don't build this directly — [`SqlMacro::scalar`] produces the
/// `Scalar` variant and [`SqlMacro::table`] produces the `Table` variant.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MacroBody {
    /// A scalar SQL expression, emitted in parentheses as `AS (expression)`.
    ///
    /// For instance: `"greatest(lo, least(hi, x))"`
    Scalar(String),

    /// A SQL query, emitted as `AS TABLE query`.
    ///
    /// For instance: `"SELECT * FROM tbl WHERE active = true"`
    Table(String),
}
78
79/// A SQL macro definition ready to be registered with `DuckDB`.
80///
81/// Use [`SqlMacro::scalar`] or [`SqlMacro::table`] to construct, then call
82/// [`SqlMacro::register`] to install. Use [`SqlMacro::to_sql`] to inspect
83/// the generated `CREATE MACRO` statement without a live connection.
84///
85/// # Example
86///
87/// ```rust
88/// use quack_rs::sql_macro::SqlMacro;
89///
90/// let m = SqlMacro::scalar("add", &["a", "b"], "a + b").unwrap();
91/// assert_eq!(m.to_sql(), "CREATE OR REPLACE MACRO add(a, b) AS (a + b)");
92/// ```
93#[derive(Debug, Clone)]
94pub struct SqlMacro {
95    name: String,
96    params: Vec<String>,
97    body: MacroBody,
98}
99
100impl SqlMacro {
101    /// Creates a scalar SQL macro definition.
102    ///
103    /// Registers as:
104    /// ```sql
105    /// CREATE OR REPLACE MACRO name(params) AS (expression)
106    /// ```
107    ///
108    /// # Errors
109    ///
110    /// Returns [`ExtensionError`] if `name` or any parameter name is invalid.
111    /// See [`validate_function_name`]
112    /// for naming rules.
113    ///
114    /// # Example
115    ///
116    /// ```rust
117    /// use quack_rs::sql_macro::SqlMacro;
118    ///
119    /// let m = SqlMacro::scalar("clamp", &["x", "lo", "hi"], "greatest(lo, least(hi, x))")?;
120    /// # Ok::<_, quack_rs::error::ExtensionError>(())
121    /// ```
122    pub fn scalar(
123        name: &str,
124        params: &[&str],
125        expression: impl Into<String>,
126    ) -> Result<Self, ExtensionError> {
127        let (name, params) = validate_name_and_params(name, params)?;
128        Ok(Self {
129            name,
130            params,
131            body: MacroBody::Scalar(expression.into()),
132        })
133    }
134
135    /// Creates a table SQL macro definition.
136    ///
137    /// Registers as:
138    /// ```sql
139    /// CREATE OR REPLACE MACRO name(params) AS TABLE query
140    /// ```
141    ///
142    /// # Errors
143    ///
144    /// Returns [`ExtensionError`] if `name` or any parameter name is invalid.
145    ///
146    /// # Example
147    ///
148    /// ```rust
149    /// use quack_rs::sql_macro::SqlMacro;
150    ///
151    /// let m = SqlMacro::table(
152    ///     "active_rows",
153    ///     &["tbl"],
154    ///     "SELECT * FROM tbl WHERE active = true",
155    /// )?;
156    /// # Ok::<_, quack_rs::error::ExtensionError>(())
157    /// ```
158    pub fn table(
159        name: &str,
160        params: &[&str],
161        query: impl Into<String>,
162    ) -> Result<Self, ExtensionError> {
163        let (name, params) = validate_name_and_params(name, params)?;
164        Ok(Self {
165            name,
166            params,
167            body: MacroBody::Table(query.into()),
168        })
169    }
170
171    /// Returns the `CREATE OR REPLACE MACRO` SQL statement for this definition.
172    ///
173    /// Useful for logging, testing, and inspection without a live connection.
174    ///
175    /// # Example
176    ///
177    /// ```rust
178    /// use quack_rs::sql_macro::SqlMacro;
179    ///
180    /// let m = SqlMacro::scalar("add", &["a", "b"], "a + b").unwrap();
181    /// assert_eq!(m.to_sql(), "CREATE OR REPLACE MACRO add(a, b) AS (a + b)");
182    ///
183    /// let t = SqlMacro::table("active_rows", &["tbl"], "SELECT * FROM tbl WHERE active = true").unwrap();
184    /// assert_eq!(
185    ///     t.to_sql(),
186    ///     "CREATE OR REPLACE MACRO active_rows(tbl) AS TABLE SELECT * FROM tbl WHERE active = true"
187    /// );
188    /// ```
189    #[must_use]
190    pub fn to_sql(&self) -> String {
191        let params = self.params.join(", ");
192        match &self.body {
193            MacroBody::Scalar(expr) => {
194                format!(
195                    "CREATE OR REPLACE MACRO {}({}) AS ({})",
196                    self.name, params, expr
197                )
198            }
199            MacroBody::Table(query) => {
200                format!(
201                    "CREATE OR REPLACE MACRO {}({}) AS TABLE {}",
202                    self.name, params, query
203                )
204            }
205        }
206    }
207
208    /// Registers this macro on the given connection.
209    ///
210    /// Executes the `CREATE OR REPLACE MACRO` statement via `duckdb_query`.
211    ///
212    /// # Errors
213    ///
214    /// Returns [`ExtensionError`] if `DuckDB` rejects the SQL statement.
215    /// The error message is extracted from `duckdb_result_error`.
216    ///
217    /// # Safety
218    ///
219    /// `con` must be a valid, open [`duckdb_connection`].
220    pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
221        let sql = self.to_sql();
222        // SAFETY: caller guarantees con is valid and open.
223        unsafe { execute_sql(con, &sql) }
224    }
225
226    /// Returns the macro name.
227    #[must_use]
228    pub fn name(&self) -> &str {
229        &self.name
230    }
231
232    /// Returns the macro parameter names.
233    #[must_use]
234    pub fn params(&self) -> &[String] {
235        &self.params
236    }
237
238    /// Returns the macro body.
239    #[must_use]
240    pub const fn body(&self) -> &MacroBody {
241        &self.body
242    }
243}
244
245/// Validates a macro name and all parameter names using the same rules as
246/// function names: `[a-z_][a-z0-9_]*`, max 256 chars.
247fn validate_name_and_params(
248    name: &str,
249    params: &[&str],
250) -> Result<(String, Vec<String>), ExtensionError> {
251    validate_function_name(name)?;
252    for &param in params {
253        validate_function_name(param)
254            .map_err(|e| ExtensionError::new(format!("invalid parameter name '{param}': {e}")))?;
255    }
256    Ok((
257        name.to_owned(),
258        params.iter().map(|&p| p.to_owned()).collect(),
259    ))
260}
261
262/// Executes a SQL statement on `con`, surfacing any `DuckDB` error.
263///
264/// Always calls `duckdb_destroy_result`, even on failure.
265///
266/// # Safety
267///
268/// `con` must be a valid, open [`duckdb_connection`].
269unsafe fn execute_sql(con: duckdb_connection, sql: &str) -> Result<(), ExtensionError> {
270    let c_sql = CString::new(sql)
271        .map_err(|_| ExtensionError::new("SQL statement contains interior null bytes"))?;
272
273    // Zero-initialize: duckdb_result contains only integer and pointer fields,
274    // all of which are valid when zero / null.
275    //
276    // SAFETY: duckdb_result is a C struct; zero is a valid bit pattern for every field.
277    let mut result: duckdb_result = unsafe { std::mem::zeroed() };
278
279    // SAFETY: con is valid; c_sql is a valid nul-terminated C string.
280    let rc = unsafe { duckdb_query(con, c_sql.as_ptr(), &raw mut result) };
281
282    // Extract the error message before freeing, because duckdb_result_error
283    // returns a pointer into the result's internal buffer.
284    let outcome = if rc == DuckDBSuccess {
285        Ok(())
286    } else {
287        // SAFETY: result was populated by duckdb_query; duckdb_result_error
288        // returns a pointer valid until duckdb_destroy_result.
289        let ptr = unsafe { duckdb_result_error(&raw mut result) };
290        let msg = if ptr.is_null() {
291            "DuckDB macro registration failed (no error message available)".to_string()
292        } else {
293            // SAFETY: ptr is a valid nul-terminated C string owned by the result.
294            unsafe { CStr::from_ptr(ptr) }
295                .to_string_lossy()
296                .into_owned()
297        };
298        Err(ExtensionError::new(msg))
299    };
300
301    // SAFETY: result was populated by duckdb_query and must always be freed.
302    unsafe { duckdb_destroy_result(&raw mut result) };
303
304    outcome
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // to_sql() — pure-Rust, no DuckDB connection needed
    // -----------------------------------------------------------------------

    #[test]
    fn scalar_no_params_to_sql() {
        let m = SqlMacro::scalar("pi", &[], "3.14159265358979").unwrap();
        assert_eq!(
            m.to_sql(),
            "CREATE OR REPLACE MACRO pi() AS (3.14159265358979)"
        );
    }

    #[test]
    fn scalar_one_param_to_sql() {
        let m = SqlMacro::scalar("double_it", &["x"], "x * 2").unwrap();
        assert_eq!(
            m.to_sql(),
            "CREATE OR REPLACE MACRO double_it(x) AS (x * 2)"
        );
    }

    #[test]
    fn scalar_multiple_params_to_sql() {
        let m = SqlMacro::scalar("add", &["a", "b"], "a + b").unwrap();
        assert_eq!(m.to_sql(), "CREATE OR REPLACE MACRO add(a, b) AS (a + b)");
    }

    #[test]
    fn scalar_complex_expression_to_sql() {
        let m =
            SqlMacro::scalar("clamp", &["x", "lo", "hi"], "greatest(lo, least(hi, x))").unwrap();
        assert_eq!(
            m.to_sql(),
            "CREATE OR REPLACE MACRO clamp(x, lo, hi) AS (greatest(lo, least(hi, x)))"
        );
    }

    #[test]
    fn scalar_body_is_inserted_verbatim() {
        // The body is never escaped: quotes and operators pass through untouched.
        let m = SqlMacro::scalar("greet", &["who"], "'hello, ' || who").unwrap();
        assert_eq!(
            m.to_sql(),
            "CREATE OR REPLACE MACRO greet(who) AS ('hello, ' || who)"
        );
    }

    #[test]
    fn scalar_accepts_owned_string_body() {
        let body = String::from("x - 1");
        let m = SqlMacro::scalar("decrement", &["x"], body).unwrap();
        assert_eq!(
            m.to_sql(),
            "CREATE OR REPLACE MACRO decrement(x) AS (x - 1)"
        );
    }

    #[test]
    fn table_no_params_to_sql() {
        let m = SqlMacro::table("all_data", &[], "SELECT 1 AS n").unwrap();
        assert_eq!(
            m.to_sql(),
            "CREATE OR REPLACE MACRO all_data() AS TABLE SELECT 1 AS n"
        );
    }

    #[test]
    fn table_with_param_to_sql() {
        let m = SqlMacro::table(
            "active_rows",
            &["tbl"],
            "SELECT * FROM tbl WHERE active = true",
        )
        .unwrap();
        assert_eq!(
            m.to_sql(),
            "CREATE OR REPLACE MACRO active_rows(tbl) AS TABLE SELECT * FROM tbl WHERE active = true"
        );
    }

    // -----------------------------------------------------------------------
    // Name and parameter validation
    // -----------------------------------------------------------------------

    #[test]
    fn invalid_macro_name_uppercase_rejected() {
        assert!(SqlMacro::scalar("MyMacro", &[], "1").is_err());
    }

    #[test]
    fn invalid_macro_name_hyphen_rejected() {
        assert!(SqlMacro::scalar("my-macro", &[], "1").is_err());
    }

    #[test]
    fn invalid_macro_name_empty_rejected() {
        assert!(SqlMacro::scalar("", &[], "1").is_err());
    }

    #[test]
    fn invalid_param_uppercase_rejected() {
        let err = SqlMacro::scalar("f", &["BadParam"], "1").unwrap_err();
        assert!(err.as_str().contains("BadParam"));
    }

    #[test]
    fn invalid_param_hyphen_rejected() {
        assert!(SqlMacro::scalar("f", &["a-b"], "1").is_err());
    }

    #[test]
    fn invalid_param_reported_after_valid_ones() {
        let err = SqlMacro::scalar("f", &["a", "b", "Oops"], "1").unwrap_err();
        assert!(err.as_str().contains("Oops"));
    }

    #[test]
    fn table_invalid_name_rejected() {
        assert!(SqlMacro::table("BadTable", &[], "SELECT 1").is_err());
    }

    #[test]
    fn table_invalid_param_rejected_with_name_in_message() {
        let err = SqlMacro::table("t", &["Bad-Param"], "SELECT 1").unwrap_err();
        assert!(err.as_str().contains("Bad-Param"));
    }

    #[test]
    fn valid_underscore_prefix_param() {
        assert!(SqlMacro::scalar("f", &["_x"], "1").is_ok());
    }

    #[test]
    fn valid_single_letter_params() {
        let m = SqlMacro::scalar("clamp", &["x", "lo", "hi"], "1").unwrap();
        assert_eq!(m.params(), ["x", "lo", "hi"]);
    }

    #[test]
    fn name_and_params_stored_correctly() {
        let m = SqlMacro::scalar("f", &["a", "b", "c"], "a+b+c").unwrap();
        assert_eq!(m.name(), "f");
        assert_eq!(m.params(), ["a", "b", "c"]);
    }

    #[test]
    fn table_name_and_params_stored_correctly() {
        let m = SqlMacro::table("t", &["src", "lim"], "SELECT * FROM src LIMIT lim").unwrap();
        assert_eq!(m.name(), "t");
        assert_eq!(m.params(), ["src", "lim"]);
    }

    // -----------------------------------------------------------------------
    // Body variant accessors
    // -----------------------------------------------------------------------

    #[test]
    fn scalar_body_variant() {
        let m = SqlMacro::scalar("f", &["x"], "x + 1").unwrap();
        assert_eq!(m.body(), &MacroBody::Scalar("x + 1".to_string()));
    }

    #[test]
    fn table_body_variant() {
        let m = SqlMacro::table("t", &[], "SELECT 1").unwrap();
        assert_eq!(m.body(), &MacroBody::Table("SELECT 1".to_string()));
    }

    // -----------------------------------------------------------------------
    // Clone and Debug
    // -----------------------------------------------------------------------

    #[test]
    fn sql_macro_is_cloneable() {
        let m = SqlMacro::scalar("f", &["x"], "x").unwrap();
        let m2 = m.clone();
        assert_eq!(m.to_sql(), m2.to_sql());
    }

    #[test]
    fn macro_body_is_eq() {
        assert_eq!(MacroBody::Scalar("x".into()), MacroBody::Scalar("x".into()));
        assert_ne!(MacroBody::Scalar("x".into()), MacroBody::Table("x".into()));
    }
}