Skip to main content

rusqlite/
functions.rs

1//! Create or redefine SQL functions.
2//!
3//! # Example
4//!
//! Adding a `regexp` function to a connection. Compiled regular expressions
//! are cached with SQLite's [Function Auxiliary Data](https://www.sqlite.org/c3ref/get_auxdata.html)
//! interface (through [`Context::get_or_create_aux`]), so a constant pattern
//! typically does not have to be recompiled for every row. The unit tests for
//! this module contain a similar implementation.
10//!
11//! ```rust
12//! use regex::Regex;
13//! use rusqlite::functions::FunctionFlags;
14//! use rusqlite::{Connection, Error, Result};
15//! use std::sync::Arc;
16//! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
17//!
18//! fn add_regexp_function(db: &Connection) -> Result<()> {
19//!     db.create_scalar_function(
20//!         "regexp",
21//!         2,
22//!         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
23//!         move |ctx| {
24//!             assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
25//!             let regexp: Arc<Regex> = ctx.get_or_create_aux(0, |vr| -> Result<_, BoxError> {
26//!                 Ok(Regex::new(vr.as_str()?)?)
27//!             })?;
28//!             let is_match = {
29//!                 let text = ctx
30//!                     .get_raw(1)
31//!                     .as_str()
32//!                     .map_err(|e| Error::UserFunctionError(e.into()))?;
33//!
34//!                 regexp.is_match(text)
35//!             };
36//!
37//!             Ok(is_match)
38//!         },
39//!     )
40//! }
41//!
42//! fn main() -> Result<()> {
43//!     let db = Connection::open_in_memory()?;
44//!     add_regexp_function(&db)?;
45//!
46//!     let is_match: bool =
47//!         db.query_row("SELECT regexp('[aeiou]*', 'aaaaeeeiii')", [], |row| {
48//!             row.get(0)
49//!         })?;
50//!
51//!     assert!(is_match);
52//!     Ok(())
53//! }
54//! ```
55use std::any::Any;
56use std::ffi::{c_int, c_uint, c_void};
57use std::marker::PhantomData;
58use std::ops::Deref;
59use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
60use std::ptr;
61use std::slice;
62use std::sync::Arc;
63
64use crate::ffi;
65use crate::ffi::sqlite3_context;
66use crate::ffi::sqlite3_value;
67
68use crate::context::set_result;
69use crate::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef};
70use crate::util::free_boxed_value;
71use crate::{str_to_cstring, Connection, Error, InnerConnection, Name, Result};
72
73unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
74    if let Error::SqliteFailure(ref err, ref s) = *err {
75        ffi::sqlite3_result_error_code(ctx, err.extended_code);
76        if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
77            ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
78        }
79    } else {
80        ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION);
81        if let Ok(cstr) = str_to_cstring(&err.to_string()) {
82            ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
83        }
84    }
85}
86
87/// Context is a wrapper for the SQLite function
88/// evaluation context.
89pub struct Context<'a> {
90    ctx: *mut sqlite3_context,
91    args: &'a [*mut sqlite3_value],
92}
93
94impl Context<'_> {
95    /// Returns the number of arguments to the function.
96    #[inline]
97    #[must_use]
98    pub fn len(&self) -> usize {
99        self.args.len()
100    }
101
102    /// Returns `true` when there is no argument.
103    #[inline]
104    #[must_use]
105    pub fn is_empty(&self) -> bool {
106        self.args.is_empty()
107    }
108
109    /// Returns the `idx`th argument as a `T`.
110    ///
111    /// # Failure
112    ///
113    /// Will panic if `idx` is greater than or equal to
114    /// [`self.len()`](Context::len).
115    ///
116    /// Will return Err if the underlying SQLite type cannot be converted to a
117    /// `T`.
118    pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
119        let arg = self.args[idx];
120        let value = unsafe { ValueRef::from_value(arg) };
121        FromSql::column_result(value).map_err(|err| match err {
122            FromSqlError::InvalidType => {
123                Error::InvalidFunctionParameterType(idx, value.data_type())
124            }
125            FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
126            FromSqlError::Utf8Error(err) => Error::Utf8Error(idx, err),
127            FromSqlError::Other(err) => {
128                Error::FromSqlConversionFailure(idx, value.data_type(), err)
129            }
130            FromSqlError::InvalidBlobSize { .. } => {
131                Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
132            }
133        })
134    }
135
136    /// Return raw pointer at `idx`
137    /// # Safety
138    /// This function is unsafe because it uses raw pointer and cast
139    #[cfg(feature = "pointer")]
140    pub unsafe fn get_pointer<T: 'static>(
141        &self,
142        idx: usize,
143        ptr_type: &'static std::ffi::CStr,
144    ) -> Option<&T> {
145        let arg = self.args[idx];
146        debug_assert_eq!(unsafe { ffi::sqlite3_value_type(arg) }, ffi::SQLITE_NULL);
147        unsafe {
148            ffi::sqlite3_value_pointer(arg, ptr_type.as_ptr())
149                .cast::<T>()
150                .as_ref()
151        }
152    }
153
154    /// Returns the `idx`th argument as a `ValueRef`.
155    ///
156    /// # Failure
157    ///
158    /// Will panic if `idx` is greater than or equal to
159    /// [`self.len()`](Context::len).
160    #[inline]
161    #[must_use]
162    pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
163        let arg = self.args[idx];
164        unsafe { ValueRef::from_value(arg) }
165    }
166
167    /// Returns the `idx`th argument as a `SqlFnArg`.
168    /// To be used when the SQL function result is one of its arguments.
169    #[inline]
170    #[must_use]
171    pub fn get_arg(&self, idx: usize) -> SqlFnArg {
172        assert!(idx < self.len());
173        SqlFnArg { idx }
174    }
175
176    /// Returns the subtype of `idx`th argument.
177    ///
178    /// # Failure
179    ///
180    /// Will panic if `idx` is greater than or equal to
181    /// [`self.len()`](Context::len).
182    pub fn get_subtype(&self, idx: usize) -> c_uint {
183        let arg = self.args[idx];
184        unsafe { ffi::sqlite3_value_subtype(arg) }
185    }
186
187    /// Fetch or insert the auxiliary data associated with a particular
188    /// parameter. This is intended to be an easier-to-use way of fetching it
189    /// compared to calling [`get_aux`](Context::get_aux) and
190    /// [`set_aux`](Context::set_aux) separately.
191    ///
192    /// See `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
193    /// this feature, or the unit tests of this module for an example.
194    ///
195    /// # Failure
196    ///
197    /// Will panic if `arg` is greater than or equal to
198    /// [`self.len()`](Context::len).
199    pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
200    where
201        T: Send + Sync + 'static,
202        E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
203        F: FnOnce(ValueRef<'_>) -> Result<T, E>,
204    {
205        if let Some(v) = self.get_aux(arg)? {
206            Ok(v)
207        } else {
208            let vr = self.get_raw(arg as usize);
209            self.set_aux(
210                arg,
211                func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
212            )
213        }
214    }
215
216    /// Sets the auxiliary data associated with a particular parameter. See
217    /// `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
218    /// this feature, or the unit tests of this module for an example.
219    ///
220    /// # Failure
221    ///
222    /// Will panic if `arg` is greater than or equal to
223    /// [`self.len()`](Context::len).
224    pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
225        assert!(arg < self.len() as i32);
226        let orig: Arc<T> = Arc::new(value);
227        let inner: AuxInner = orig.clone();
228        let outer = Box::new(inner);
229        let raw: *mut AuxInner = Box::into_raw(outer);
230        unsafe {
231            ffi::sqlite3_set_auxdata(
232                self.ctx,
233                arg,
234                raw.cast(),
235                Some(free_boxed_value::<AuxInner>),
236            );
237        };
238        Ok(orig)
239    }
240
241    /// Gets the auxiliary data that was associated with a given parameter via
242    /// [`set_aux`](Context::set_aux). Returns `Ok(None)` if no data has been
243    /// associated, and Ok(Some(v)) if it has. Returns an error if the
244    /// requested type does not match.
245    ///
246    /// # Failure
247    ///
248    /// Will panic if `arg` is greater than or equal to
249    /// [`self.len()`](Context::len).
250    pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
251        assert!(arg < self.len() as i32);
252        let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
253        if p.is_null() {
254            Ok(None)
255        } else {
256            let v: AuxInner = AuxInner::clone(unsafe { &*p });
257            v.downcast::<T>()
258                .map(Some)
259                .map_err(|_| Error::GetAuxWrongType)
260        }
261    }
262
263    /// Get the db connection handle via [sqlite3_context_db_handle](https://www.sqlite.org/c3ref/context_db_handle.html)
264    ///
265    /// # Safety
266    ///
267    /// This function is marked unsafe because there is a potential for other
268    /// references to the connection to be sent across threads, [see this comment](https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213).
269    pub unsafe fn get_connection(&self) -> Result<ConnectionRef<'_>> {
270        let handle = ffi::sqlite3_context_db_handle(self.ctx);
271        Ok(ConnectionRef {
272            conn: Connection::from_handle(handle)?,
273            phantom: PhantomData,
274        })
275    }
276}
277
278/// A reference to a connection handle with a lifetime bound to something.
279pub struct ConnectionRef<'ctx> {
280    // comes from Connection::from_handle(sqlite3_context_db_handle(...))
281    // and is non-owning
282    conn: Connection,
283    phantom: PhantomData<&'ctx Context<'ctx>>,
284}
285
286impl Deref for ConnectionRef<'_> {
287    type Target = Connection;
288
289    #[inline]
290    fn deref(&self) -> &Connection {
291        &self.conn
292    }
293}
294
/// Type-erased, reference-counted payload stored as SQLite auxiliary data.
type AuxInner = Arc<dyn Any + Send + Sync>;

/// Optional subtype attached to an SQL function value (`None` means no subtype).
pub type SubType = Option<c_uint>;
299
300/// Result of an SQL function
301pub trait SqlFnOutput {
302    /// Converts Rust value to SQLite value with an optional subtype
303    fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)>;
304}
305
306impl<T: ToSql> SqlFnOutput for T {
307    #[inline]
308    fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
309        ToSql::to_sql(self).map(|o| (o, None))
310    }
311}
312
313impl<T: ToSql> SqlFnOutput for (T, SubType) {
314    fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
315        ToSql::to_sql(&self.0).map(|o| (o, self.1))
316    }
317}
318
319/// n-th arg of an SQL scalar function
320pub struct SqlFnArg {
321    idx: usize,
322}
323impl ToSql for SqlFnArg {
324    fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
325        Ok(ToSqlOutput::Arg(self.idx))
326    }
327}
328
329unsafe fn sql_result<T: SqlFnOutput>(
330    ctx: *mut sqlite3_context,
331    args: &[*mut sqlite3_value],
332    r: Result<T>,
333) {
334    let t = r.as_ref().map(SqlFnOutput::to_sql);
335
336    match t {
337        Ok(Ok((ref value, sub_type))) => {
338            set_result(ctx, args, value);
339            if let Some(sub_type) = sub_type {
340                ffi::sqlite3_result_subtype(ctx, sub_type);
341            }
342        }
343        Ok(Err(err)) => report_error(ctx, &err),
344        Err(err) => report_error(ctx, err),
345    }
346}
347
348/// Aggregate is the callback interface for user-defined
349/// aggregate function.
350///
351/// `A` is the type of the aggregation context and `T` is the type of the final
352/// result. Implementations should be stateless.
353pub trait Aggregate<A, T>
354where
355    A: RefUnwindSafe + UnwindSafe,
356    T: SqlFnOutput,
357{
358    /// Initializes the aggregation context. Will be called prior to the first
359    /// call to [`step()`](Aggregate::step) to set up the context for an
360    /// invocation of the function. (Note: `init()` will not be called if
361    /// there are no rows.)
362    fn init(&self, ctx: &mut Context<'_>) -> Result<A>;
363
364    /// "step" function called once for each row in an aggregate group. May be
365    /// called 0 times if there are no rows.
366    fn step(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
367
368    /// Computes and returns the final result. Will be called exactly once for
369    /// each invocation of the function. If [`step()`](Aggregate::step) was
370    /// called at least once, will be given `Some(A)` (the same `A` as was
371    /// created by [`init`](Aggregate::init) and given to
372    /// [`step`](Aggregate::step)); if [`step()`](Aggregate::step) was not
373    /// called (because the function is running against 0 rows), will be
374    /// given `None`.
375    ///
376    /// The passed context will have no arguments.
377    fn finalize(&self, ctx: &mut Context<'_>, acc: Option<A>) -> Result<T>;
378}
379
/// Callback interface for a user-defined aggregate window function.
///
/// Extends [`Aggregate`] with the two callbacks SQLite uses to evaluate the
/// function over a moving window.
#[cfg(feature = "window")]
pub trait WindowAggregate<A: RefUnwindSafe + UnwindSafe, T: SqlFnOutput>: Aggregate<A, T> {
    /// Returns the current value of the aggregate (SQLite's `xValue`).
    ///
    /// Unlike [`finalize()`](Aggregate::finalize) (`xFinal`), this only
    /// borrows the accumulator, which must stay usable for later rows. `acc`
    /// is `None` when no accumulator has been created yet.
    fn value(&self, acc: Option<&mut A>) -> Result<T>;

    /// Removes one row from the current window (SQLite's `xInverse`).
    fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
}
395
396bitflags::bitflags! {
397    /// Function Flags.
398    /// See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html)
399    /// and [Function Flags](https://sqlite.org/c3ref/c_deterministic.html) for details.
400    #[derive(Clone, Copy, Debug)]
401    #[repr(C)]
402    pub struct FunctionFlags: c_int {
403        /// Specifies UTF-8 as the text encoding this SQL function prefers for its parameters.
404        const SQLITE_UTF8     = ffi::SQLITE_UTF8;
405        /// Specifies UTF-16 using little-endian byte order as the text encoding this SQL function prefers for its parameters.
406        const SQLITE_UTF16LE  = ffi::SQLITE_UTF16LE;
407        /// Specifies UTF-16 using big-endian byte order as the text encoding this SQL function prefers for its parameters.
408        const SQLITE_UTF16BE  = ffi::SQLITE_UTF16BE;
409        /// Specifies UTF-16 using native byte order as the text encoding this SQL function prefers for its parameters.
410        const SQLITE_UTF16    = ffi::SQLITE_UTF16;
411        /// Means that the function always gives the same output when the input parameters are the same.
412        const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC; // 3.8.3
413        /// Means that the function may only be invoked from top-level SQL.
414        const SQLITE_DIRECTONLY    = 0x0000_0008_0000; // 3.30.0
415        /// Indicates to SQLite that a function may call `sqlite3_value_subtype()` to inspect the subtypes of its arguments.
416        const SQLITE_SUBTYPE       = 0x0000_0010_0000; // 3.30.0
417        /// Means that the function is unlikely to cause problems even if misused.
418        const SQLITE_INNOCUOUS     = 0x0000_0020_0000; // 3.31.0
419        /// Indicates to SQLite that a function might call `sqlite3_result_subtype()` to cause a subtype to be associated with its result.
420        const SQLITE_RESULT_SUBTYPE     = 0x0000_0100_0000; // 3.45.0
421        /// Indicates that the function is an aggregate that internally orders the values provided to the first argument.
422        const SQLITE_SELFORDER1 = 0x0000_0200_0000; // 3.47.0
423    }
424}
425
426impl Default for FunctionFlags {
427    #[inline]
428    fn default() -> Self {
429        Self::SQLITE_UTF8
430    }
431}
432
433impl Connection {
434    /// Attach a user-defined scalar function to
435    /// this database connection.
436    ///
437    /// `fn_name` is the name the function will be accessible from SQL.
438    /// `n_arg` is the number of arguments to the function. Use `-1` for a
439    /// variable number. If the function always returns the same value
440    /// given the same input, `deterministic` should be `true`.
441    ///
442    /// The function will remain available until the connection is closed or
443    /// until it is explicitly removed via
444    /// [`remove_function`](Connection::remove_function).
445    ///
446    /// # Example
447    ///
448    /// ```rust
449    /// # use rusqlite::{Connection, Result};
450    /// # use rusqlite::functions::FunctionFlags;
451    /// fn scalar_function_example(db: Connection) -> Result<()> {
452    ///     db.create_scalar_function(
453    ///         "halve",
454    ///         1,
455    ///         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
456    ///         |ctx| {
457    ///             let value = ctx.get::<f64>(0)?;
458    ///             Ok(value / 2f64)
459    ///         },
460    ///     )?;
461    ///
462    ///     let six_halved: f64 = db.query_row("SELECT halve(6)", [], |r| r.get(0))?;
463    ///     assert_eq!(six_halved, 3f64);
464    ///     Ok(())
465    /// }
466    /// ```
467    ///
468    /// # Failure
469    ///
470    /// Will return Err if the function could not be attached to the connection.
471    #[inline]
472    pub fn create_scalar_function<F, N: Name, T>(
473        &self,
474        fn_name: N,
475        n_arg: c_int,
476        flags: FunctionFlags,
477        x_func: F,
478    ) -> Result<()>
479    where
480        F: Fn(&Context<'_>) -> Result<T> + Send + 'static,
481        T: SqlFnOutput,
482    {
483        self.db
484            .borrow_mut()
485            .create_scalar_function(fn_name, n_arg, flags, x_func)
486    }
487
488    /// Attach a user-defined aggregate function to this
489    /// database connection.
490    ///
491    /// # Failure
492    ///
493    /// Will return Err if the function could not be attached to the connection.
494    #[inline]
495    pub fn create_aggregate_function<A, D, N: Name, T>(
496        &self,
497        fn_name: N,
498        n_arg: c_int,
499        flags: FunctionFlags,
500        aggr: D,
501    ) -> Result<()>
502    where
503        A: RefUnwindSafe + UnwindSafe,
504        D: Aggregate<A, T> + 'static,
505        T: SqlFnOutput,
506    {
507        self.db
508            .borrow_mut()
509            .create_aggregate_function(fn_name, n_arg, flags, aggr)
510    }
511
512    /// Attach a user-defined aggregate window function to
513    /// this database connection.
514    ///
515    /// See `https://sqlite.org/windowfunctions.html#udfwinfunc` for more
516    /// information.
517    #[cfg(feature = "window")]
518    #[inline]
519    pub fn create_window_function<A, N: Name, W, T>(
520        &self,
521        fn_name: N,
522        n_arg: c_int,
523        flags: FunctionFlags,
524        aggr: W,
525    ) -> Result<()>
526    where
527        A: RefUnwindSafe + UnwindSafe,
528        W: WindowAggregate<A, T> + 'static,
529        T: SqlFnOutput,
530    {
531        self.db
532            .borrow_mut()
533            .create_window_function(fn_name, n_arg, flags, aggr)
534    }
535
536    /// Removes a user-defined function from this
537    /// database connection.
538    ///
539    /// `fn_name` and `n_arg` should match the name and number of arguments
540    /// given to [`create_scalar_function`](Connection::create_scalar_function)
541    /// or [`create_aggregate_function`](Connection::create_aggregate_function).
542    ///
543    /// # Failure
544    ///
545    /// Will return Err if the function could not be removed.
546    #[inline]
547    pub fn remove_function<N: Name>(&self, fn_name: N, n_arg: c_int) -> Result<()> {
548        self.db.borrow_mut().remove_function(fn_name, n_arg)
549    }
550}
551
552impl InnerConnection {
553    /// ```compile_fail
554    /// use rusqlite::{functions::FunctionFlags, Connection, Result};
555    /// fn main() -> Result<()> {
556    ///     let db = Connection::open_in_memory()?;
557    ///     {
558    ///         let mut called = std::sync::atomic::AtomicBool::new(false);
559    ///         db.create_scalar_function(
560    ///             "test",
561    ///             0,
562    ///             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
563    ///             |_| {
564    ///                 called.store(true, std::sync::atomic::Ordering::Relaxed);
565    ///                 Ok(true)
566    ///             },
567    ///         );
568    ///     }
569    ///     let result: Result<bool> = db.query_row("SELECT test()", [], |r| r.get(0));
570    ///     assert!(result?);
571    ///     Ok(())
572    /// }
573    /// ```
574    fn create_scalar_function<F, N: Name, T>(
575        &mut self,
576        fn_name: N,
577        n_arg: c_int,
578        flags: FunctionFlags,
579        x_func: F,
580    ) -> Result<()>
581    where
582        F: Fn(&Context<'_>) -> Result<T> + Send + 'static,
583        T: SqlFnOutput,
584    {
585        unsafe extern "C" fn call_boxed_closure<F, T>(
586            ctx: *mut sqlite3_context,
587            argc: c_int,
588            argv: *mut *mut sqlite3_value,
589        ) where
590            F: Fn(&Context<'_>) -> Result<T>,
591            T: SqlFnOutput,
592        {
593            let args = slice::from_raw_parts(argv, argc as usize);
594            let r = catch_unwind(|| {
595                let boxed_f: *const F = ffi::sqlite3_user_data(ctx).cast::<F>();
596                assert!(!boxed_f.is_null(), "Internal error - null function pointer");
597                let ctx = Context { ctx, args };
598                (*boxed_f)(&ctx)
599            });
600            let t = match r {
601                Err(_) => {
602                    report_error(ctx, &Error::UnwindingPanic);
603                    return;
604                }
605                Ok(r) => r,
606            };
607            sql_result(ctx, args, t);
608        }
609
610        let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
611        let c_name = fn_name.as_cstr()?;
612        let r = unsafe {
613            ffi::sqlite3_create_function_v2(
614                self.db(),
615                c_name.as_ptr(),
616                n_arg,
617                flags.bits(),
618                boxed_f.cast::<c_void>(),
619                Some(call_boxed_closure::<F, T>),
620                None,
621                None,
622                Some(free_boxed_value::<F>),
623            )
624        };
625        self.decode_result(r)
626    }
627
628    fn create_aggregate_function<A, D, N: Name, T>(
629        &mut self,
630        fn_name: N,
631        n_arg: c_int,
632        flags: FunctionFlags,
633        aggr: D,
634    ) -> Result<()>
635    where
636        A: RefUnwindSafe + UnwindSafe,
637        D: Aggregate<A, T> + 'static,
638        T: SqlFnOutput,
639    {
640        let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
641        let c_name = fn_name.as_cstr()?;
642        let r = unsafe {
643            ffi::sqlite3_create_function_v2(
644                self.db(),
645                c_name.as_ptr(),
646                n_arg,
647                flags.bits(),
648                boxed_aggr.cast::<c_void>(),
649                None,
650                Some(call_boxed_step::<A, D, T>),
651                Some(call_boxed_final::<A, D, T>),
652                Some(free_boxed_value::<D>),
653            )
654        };
655        self.decode_result(r)
656    }
657
658    #[cfg(feature = "window")]
659    fn create_window_function<A, N: Name, W, T>(
660        &mut self,
661        fn_name: N,
662        n_arg: c_int,
663        flags: FunctionFlags,
664        aggr: W,
665    ) -> Result<()>
666    where
667        A: RefUnwindSafe + UnwindSafe,
668        W: WindowAggregate<A, T> + 'static,
669        T: SqlFnOutput,
670    {
671        let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
672        let c_name = fn_name.as_cstr()?;
673        let r = unsafe {
674            ffi::sqlite3_create_window_function(
675                self.db(),
676                c_name.as_ptr(),
677                n_arg,
678                flags.bits(),
679                boxed_aggr.cast::<c_void>(),
680                Some(call_boxed_step::<A, W, T>),
681                Some(call_boxed_final::<A, W, T>),
682                Some(call_boxed_value::<A, W, T>),
683                Some(call_boxed_inverse::<A, W, T>),
684                Some(free_boxed_value::<W>),
685            )
686        };
687        self.decode_result(r)
688    }
689
690    fn remove_function<N: Name>(&mut self, fn_name: N, n_arg: c_int) -> Result<()> {
691        let c_name = fn_name.as_cstr()?;
692        let r = unsafe {
693            ffi::sqlite3_create_function_v2(
694                self.db(),
695                c_name.as_ptr(),
696                n_arg,
697                ffi::SQLITE_UTF8,
698                ptr::null_mut(),
699                None,
700                None,
701                None,
702                None,
703            )
704        };
705        self.decode_result(r)
706    }
707}
708
709unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
710    let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
711    if pac.is_null() {
712        return None;
713    }
714    Some(pac)
715}
716
717unsafe extern "C" fn call_boxed_step<A, D, T>(
718    ctx: *mut sqlite3_context,
719    argc: c_int,
720    argv: *mut *mut sqlite3_value,
721) where
722    A: RefUnwindSafe + UnwindSafe,
723    D: Aggregate<A, T>,
724    T: SqlFnOutput,
725{
726    let Some(pac) = aggregate_context(ctx, size_of::<*mut A>()) else {
727        ffi::sqlite3_result_error_nomem(ctx);
728        return;
729    };
730
731    let r = catch_unwind(|| {
732        let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
733        assert!(
734            !boxed_aggr.is_null(),
735            "Internal error - null aggregate pointer"
736        );
737        let mut ctx = Context {
738            ctx,
739            args: slice::from_raw_parts(argv, argc as usize),
740        };
741
742        #[expect(clippy::unnecessary_cast)]
743        if (*pac as *mut A).is_null() {
744            *pac = Box::into_raw(Box::new((*boxed_aggr).init(&mut ctx)?));
745        }
746
747        (*boxed_aggr).step(&mut ctx, &mut **pac)
748    });
749    let r = match r {
750        Err(_) => {
751            report_error(ctx, &Error::UnwindingPanic);
752            return;
753        }
754        Ok(r) => r,
755    };
756    match r {
757        Ok(_) => {}
758        Err(err) => report_error(ctx, &err),
759    }
760}
761
/// `xInverse` trampoline for window functions: removes one row from the
/// current window by forwarding it to [`WindowAggregate::inverse`].
///
/// Errors returned by `inverse` and panics are reported through `ctx`.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
    ctx: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: SqlFnOutput,
{
    let Some(pac) = aggregate_context(ctx, size_of::<*mut A>()) else {
        ffi::sqlite3_result_error_nomem(ctx);
        return;
    };

    let r = catch_unwind(|| {
        let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(
            !boxed_aggr.is_null(),
            "Internal error - null aggregate pointer"
        );
        // xInverse should only run for rows previously added by xStep, so the
        // accumulator is expected to exist. If it does not, a caught panic
        // (reported below) is far better than dereferencing a null pointer,
        // which would be undefined behavior.
        let acc: *mut A = *pac;
        assert!(!acc.is_null(), "Internal error - null aggregate context");
        let mut ctx = Context {
            ctx,
            args: slice::from_raw_parts(argv, argc as usize),
        };
        (*boxed_aggr).inverse(&mut ctx, &mut *acc)
    });
    let r = match r {
        Err(_) => {
            report_error(ctx, &Error::UnwindingPanic);
            return;
        }
        Ok(r) => r,
    };
    if let Err(err) = r {
        report_error(ctx, &err);
    }
}
801
802unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
803where
804    A: RefUnwindSafe + UnwindSafe,
805    D: Aggregate<A, T>,
806    T: SqlFnOutput,
807{
808    // Within the xFinal callback, it is customary to set N=0 in calls to
809    // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
810    let a: Option<A> = match aggregate_context(ctx, 0) {
811        Some(pac) =>
812        {
813            #[expect(clippy::unnecessary_cast)]
814            if (*pac as *mut A).is_null() {
815                None
816            } else {
817                let a = Box::from_raw(*pac);
818                Some(*a)
819            }
820        }
821        None => None,
822    };
823
824    let r = catch_unwind(|| {
825        let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
826        assert!(
827            !boxed_aggr.is_null(),
828            "Internal error - null aggregate pointer"
829        );
830        let mut ctx = Context { ctx, args: &mut [] };
831        (*boxed_aggr).finalize(&mut ctx, a)
832    });
833    let t = match r {
834        Err(_) => {
835            report_error(ctx, &Error::UnwindingPanic);
836            return;
837        }
838        Ok(r) => r,
839    };
840    sql_result(ctx, &[], t);
841}
842
/// `xValue` trampoline for window functions: reports the current aggregate
/// value from [`WindowAggregate::value`] without consuming the accumulator.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: SqlFnOutput,
{
    // Asking for 0 bytes only looks up an existing context instead of
    // allocating a fresh, empty one.
    let acc_slot: Option<*mut *mut A> = match aggregate_context::<A>(ctx, 0) {
        Some(pac) if !(*pac).is_null() => Some(pac),
        _ => None,
    };

    let outcome = catch_unwind(|| {
        let boxed_aggr = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(
            !boxed_aggr.is_null(),
            "Internal error - null aggregate pointer"
        );
        let acc: Option<&mut A> = acc_slot.map(|pac| &mut **pac);
        (*boxed_aggr).value(acc)
    });

    match outcome {
        Ok(result) => sql_result(ctx, &[], result),
        Err(_) => report_error(ctx, &Error::UnwindingPanic),
    }
}
874
#[cfg(all(test, not(miri)))]
mod test {
    #[cfg(all(target_family = "wasm", target_os = "unknown"))]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    #[cfg(feature = "window")]
    use crate::functions::WindowAggregate;
    use crate::functions::{Aggregate, Context, FunctionFlags, SqlFnArg, SubType};
    use crate::{Connection, Error, Result};
    use regex::Regex;
    use std::ffi::c_double;

    /// Scalar function returning half of its single numeric argument; also
    /// checks that the owning connection is reachable from the context.
    fn half(ctx: &Context<'_>) -> Result<c_double> {
        assert!(!ctx.is_empty());
        assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
        assert!(unsafe {
            ctx.get_connection()
                .as_ref()
                .map(::std::ops::Deref::deref)
                .is_ok()
        });
        let value = ctx.get::<c_double>(0)?;
        Ok(value / 2f64)
    }

    #[test]
    fn test_function_half() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            c"half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )?;
        let result: f64 = db.one_column("SELECT half(6)", [])?;

        assert!((3f64 - result).abs() < f64::EPSILON);
        Ok(())
    }

    #[test]
    fn test_remove_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            c"half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )?;
        assert!((3f64 - db.one_column::<f64, _>("SELECT half(6)", [])?).abs() < f64::EPSILON);

        db.remove_function(c"half", 1)?;
        db.one_column::<f64, _>("SELECT half(6)", []).unwrap_err();
        Ok(())
    }

    // This implementation of a regexp scalar function uses SQLite's auxiliary data
    // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular
    // expression multiple times within one query.
    fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<bool> {
        assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
        let regexp: std::sync::Arc<Regex> = ctx
            .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
                Ok(Regex::new(vr.as_str()?)?)
            })?;

        let is_match = {
            let text = ctx
                .get_raw(1)
                .as_str()
                .map_err(|e| Error::UserFunctionError(e.into()))?;

            regexp.is_match(text)
        };

        Ok(is_match)
    }

    #[test]
    fn test_function_regexp_with_auxiliary() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            "BEGIN;
             CREATE TABLE foo (x string);
             INSERT INTO foo VALUES ('lisa');
             INSERT INTO foo VALUES ('lXsi');
             INSERT INTO foo VALUES ('lisX');
             END;",
        )?;
        db.create_scalar_function(
            c"regexp",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            regexp_with_auxiliary,
        )?;

        assert!(db.one_column::<bool, _>("SELECT regexp('l.s[aeiouy]', 'lisa')", [])?);

        assert_eq!(
            2,
            db.one_column::<i64, _>(
                "SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
                [],
            )?
        );
        Ok(())
    }

    #[test]
    fn test_varargs_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            c"my_concat",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            |ctx| {
                let mut ret = String::new();

                for idx in 0..ctx.len() {
                    let s = ctx.get::<String>(idx)?;
                    ret.push_str(&s);
                }

                Ok(ret)
            },
        )?;

        for &(expected, query) in &[
            ("", "SELECT my_concat()"),
            ("onetwo", "SELECT my_concat('one', 'two')"),
            ("abc", "SELECT my_concat('a', 'b', 'c')"),
        ] {
            assert_eq!(expected, db.one_column::<String, _>(query, [])?);
        }
        Ok(())
    }

    #[test]
    fn test_get_aux_type_checking() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(c"example", 2, FunctionFlags::default(), |ctx| {
            if !ctx.get::<bool>(1)? {
                ctx.set_aux::<i64>(0, 100)?;
            } else {
                assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
                assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100);
            }
            Ok(true)
        })?;

        // Step through *every* row. `query_row` would stop after the first row
        // (i = 0, which only stores the aux data), so the `get_aux` assertions made
        // by the second invocation would never execute. The first argument is a
        // constant, which is what lets SQLite keep its aux data between rows.
        let mut stmt = db.prepare("SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)")?;
        let results = stmt
            .query_map([], |r| r.get::<_, bool>(0))?
            .collect::<Result<Vec<bool>>>()?;
        // A failed assertion inside the function surfaces as an error above; here we
        // only confirm that both rows (and thus both code paths) were evaluated.
        assert_eq!(results, [true, true]);
        Ok(())
    }

    struct Sum;
    struct Count;

    impl Aggregate<i64, Option<i64>> for Sum {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += ctx.get::<i64>(0)?;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> {
            Ok(sum)
        }
    }

    impl Aggregate<i64, i64> for Count {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += 1;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> {
            Ok(sum.unwrap_or(0))
        }
    }

    #[test]
    fn test_sum() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function(
            c"my_sum",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )?;

        // sum should return NULL when given no rows (contrast with count below)
        let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        assert!(db.one_column::<Option<i64>, _>(no_result, [])?.is_none());

        let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        assert_eq!(4, db.one_column::<i64, _>(single_sum, [])?);

        let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
                        2, 1)";
        let result: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?;
        assert_eq!((4, 2), result);
        Ok(())
    }

    #[test]
    fn test_count() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function(
            c"my_count",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Count,
        )?;

        // count should return 0 when given no rows (contrast with sum above)
        let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        assert_eq!(db.one_column::<i64, _>(no_result, [])?, 0);

        let two_rows = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        assert_eq!(2, db.one_column::<i64, _>(two_rows, [])?);
        Ok(())
    }

    #[cfg(feature = "window")]
    impl WindowAggregate<i64, Option<i64>> for Sum {
        fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum -= ctx.get::<i64>(0)?;
            Ok(())
        }

        fn value(&self, sum: Option<&mut i64>) -> Result<Option<i64>> {
            Ok(sum.copied())
        }
    }

    #[test]
    #[cfg(feature = "window")]
    fn test_window() -> Result<()> {
        use fallible_iterator::FallibleIterator as _;

        let db = Connection::open_in_memory()?;
        db.create_window_function(
            c"sumint",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )?;
        db.execute_batch(
            "CREATE TABLE t3(x, y);
             INSERT INTO t3 VALUES('a', 4),
                     ('b', 5),
                     ('c', 3),
                     ('d', 8),
                     ('e', 1);",
        )?;

        let mut stmt = db.prepare(
            "SELECT x, sumint(y) OVER (
                   ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
                 ) AS sum_y
                 FROM t3 ORDER BY x;",
        )?;

        let results: Vec<(String, i64)> = stmt
            .query([])?
            .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
            .collect()?;
        let expected = vec![
            ("a".to_owned(), 9),
            ("b".to_owned(), 12),
            ("c".to_owned(), 16),
            ("d".to_owned(), 12),
            ("e".to_owned(), 9),
        ];
        assert_eq!(expected, results);
        Ok(())
    }

    #[test]
    fn test_sub_type() -> Result<()> {
        fn test_getsubtype(ctx: &Context<'_>) -> Result<i32> {
            Ok(ctx.get_subtype(0) as i32)
        }
        fn test_setsubtype(ctx: &Context<'_>) -> Result<(SqlFnArg, SubType)> {
            use std::ffi::c_uint;
            let value = ctx.get_arg(0);
            let sub_type = ctx.get::<c_uint>(1)?;
            Ok((value, Some(sub_type)))
        }
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            c"test_getsubtype",
            1,
            FunctionFlags::SQLITE_UTF8,
            test_getsubtype,
        )?;
        db.create_scalar_function(
            c"test_setsubtype",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_RESULT_SUBTYPE,
            test_setsubtype,
        )?;
        let result: i32 = db.one_column("SELECT test_getsubtype('hello');", [])?;
        assert_eq!(0, result);

        let result: i32 =
            db.one_column("SELECT test_getsubtype(test_setsubtype('hello',123));", [])?;
        assert_eq!(123, result);

        Ok(())
    }

    #[test]
    fn test_blob() -> Result<()> {
        fn test_len(ctx: &Context<'_>) -> Result<u32> {
            let blob = ctx.get_raw(0);
            Ok(blob
                .as_bytes_or_null()?
                .map_or(0, |b| b.len().try_into().unwrap()))
        }
        let db = Connection::open_in_memory()?;
        db.create_scalar_function("test_len", 1, FunctionFlags::SQLITE_DETERMINISTIC, test_len)?;
        assert_eq!(
            6,
            db.one_column::<u32, _>("SELECT test_len(X'53514C697465');", [])?
        );
        assert_eq!(0, db.one_column::<u32, _>("SELECT test_len(X'');", [])?);
        assert_eq!(0, db.one_column::<u32, _>("SELECT test_len(NULL);", [])?);
        Ok(())
    }

    #[test]
    #[cfg(feature = "pointer")]
    fn test_rc_pointer() -> Result<()> {
        use crate::types::ToSqlOutput;
        use std::ops::Deref as _;
        use std::rc::Rc;

        const PTR_TYPE: &std::ffi::CStr = c"my_rust_ptr";
        let rc = Rc::new(1);
        {
            let ptr = ToSqlOutput::from_rc(rc.clone(), PTR_TYPE);
            assert_eq!(2, Rc::strong_count(&rc));
            fn myfunc(ctx: &Context<'_>) -> Result<ToSqlOutput<'static>> {
                let x = unsafe { ctx.get_pointer(0, PTR_TYPE) };
                assert_eq!(x, Some(&1));
                Ok(ToSqlOutput::from_rc(Rc::new(*x.unwrap()), PTR_TYPE))
            }
            let db = Connection::open_in_memory()?;
            db.create_scalar_function("myfunc", 1, FunctionFlags::SQLITE_DETERMINISTIC, myfunc)?;
            let mut stmt = db.prepare("SELECT myfunc(?)")?;
            let result = stmt.query_one([ptr], |r| {
                unsafe { r.get_pointer::<_, i32>(0, PTR_TYPE) }.map(|opt| opt.cloned())
            })?;
            assert_eq!(result.unwrap(), *rc.deref());
        }
        // All SQLite-held clones must have been released once the statement and
        // connection are dropped.
        assert_eq!(1, Rc::strong_count(&rc));
        Ok(())
    }

    #[test]
    #[cfg(feature = "pointer")]
    fn test_box_pointer() -> Result<()> {
        use crate::types::ToSqlOutput;

        const PTR_TYPE: &std::ffi::CStr = c"my_rust_ptr";
        let value = 1;
        let ptr = ToSqlOutput::new_boxed(value, PTR_TYPE);
        fn myfunc(ctx: &Context<'_>) -> Result<ToSqlOutput<'static>> {
            let x = unsafe { ctx.get_pointer(0, PTR_TYPE) };
            assert_eq!(x, Some(&1));
            Ok(ToSqlOutput::new_boxed(*x.unwrap(), PTR_TYPE))
        }
        let db = Connection::open_in_memory()?;
        db.create_scalar_function("myfunc", 1, FunctionFlags::SQLITE_DETERMINISTIC, myfunc)?;
        let mut stmt = db.prepare("SELECT myfunc(?)")?;
        let result = stmt.query_one([ptr], |r| {
            unsafe { r.get_pointer::<_, i32>(0, PTR_TYPE) }.map(|opt| opt.cloned())
        })?;
        assert_eq!(result.unwrap(), value);
        Ok(())
    }
}