sqlite3_ext/function/mod.rs
//! Create application-defined functions.
//!
//! The functionality in this module is primarily exposed through
//! [Connection::create_scalar_function] and [Connection::create_aggregate_function].
5use super::{ffi, sqlite3_match_version, types::*, value::*, Connection, RiskLevel};
6pub use context::*;
7use std::{cmp::Ordering, ffi::CString, ptr::null_mut};
8
9mod context;
10mod stubs;
11mod test;
12
/// Builds an aggregate function instance from registration-time user data.
///
/// Aggregate functions are instantiated using the user data that was supplied when the
/// function was registered. Types implementing [Default] receive this trait through a
/// blanket implementation, which covers the case where no user data is needed.
pub trait FromUserData<UserData> {
    /// Create a fresh instance from the given user data.
    fn from_user_data(data: &UserData) -> Self;
}
22
23/// Trait for scalar functions. This trait is used with
24/// [Connection::create_scalar_function_object] to implement scalar functions that have a
25/// lifetime smaller than `'static`. It is also possible to use closures and avoid implementing
26/// this trait, see [Connection::create_scalar_function] for details.
27pub trait ScalarFunction<'db> {
28 /// Perform a single invocation. The function will be invoked with a [Context] and an
29 /// array of [ValueRef] objects. The function is required to set its output using
30 /// [Context::set_result]. If no result is set, SQL NULL is returned. If the function
31 /// returns an Err value, the SQL statement will fail, even if a result had been set.
32 fn call(&self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
33}
34
35struct ScalarClosure<F>(F)
36where
37 F: Fn(&Context, &mut [&mut ValueRef]) -> Result<()> + 'static;
38
39impl<F> ScalarFunction<'_> for ScalarClosure<F>
40where
41 F: Fn(&Context, &mut [&mut ValueRef]) -> Result<()> + 'static,
42{
43 fn call(&self, ctx: &Context, args: &mut [&mut ValueRef]) -> Result<()> {
44 self.0(ctx, args)
45 }
46}
47
48/// Implement an application-defined aggregate function which cannot be used as a window
49/// function.
50///
51/// In general, there is no reason to implement this trait instead of [AggregateFunction],
52/// because the latter provides a blanket implementation of the former.
53pub trait LegacyAggregateFunction<UserData>: FromUserData<UserData> {
54 /// Assign the default value of the aggregate function to the context using
55 /// [Context::set_result].
56 ///
57 /// This method is called when the aggregate function is invoked over an empty set of
58 /// rows. The default implementation is equivalent to
59 /// `Self::from_user_data(user_data).value(context)`.
60 fn default_value(user_data: &UserData, context: &Context) -> Result<()>
61 where
62 Self: Sized,
63 {
64 Self::from_user_data(user_data).value(context)
65 }
66
67 /// Add a new row to the aggregate.
68 fn step(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
69
70 /// Assign the current value of the aggregate function to the context using
71 /// [Context::set_result]. If no result is set, SQL NULL is returned. If the function returns
72 /// an Err value, the SQL statement will fail, even if a result had been set before the
73 /// failure.
74 fn value(&self, context: &Context) -> Result<()>;
75}
76
77/// Implement an application-defined aggregate window function.
78///
79/// The function can be registered with a database connection using
80/// [Connection::create_aggregate_function].
81pub trait AggregateFunction<UserData>: FromUserData<UserData> {
82 /// Assign the default value of the aggregate function to the context using
83 /// [Context::set_result].
84 ///
85 /// This method is called when the aggregate function is invoked over an empty set of
86 /// rows. The default implementation is equivalent to
87 /// `Self::from_user_data(user_data).value(context)`.
88 fn default_value(user_data: &UserData, context: &Context) -> Result<()>
89 where
90 Self: Sized,
91 {
92 Self::from_user_data(user_data).value(context)
93 }
94
95 /// Add a new row to the aggregate.
96 fn step(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
97
98 /// Assign the current value of the aggregate function to the context using
99 /// [Context::set_result]. If no result is set, SQL NULL is returned. If the function returns
100 /// an Err value, the SQL statement will fail, even if a result had been set before the
101 /// failure.
102 fn value(&self, context: &Context) -> Result<()>;
103
104 /// Remove the oldest presently aggregated row.
105 ///
106 /// The args are the same that were passed to [AggregateFunction::step] when this row
107 /// was added.
108 fn inverse(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
109}
110
111impl<U, F: Default> FromUserData<U> for F {
112 fn from_user_data(_: &U) -> F {
113 F::default()
114 }
115}
116
117impl<U, T: AggregateFunction<U>> LegacyAggregateFunction<U> for T {
118 fn default_value(user_data: &U, context: &Context) -> Result<()> {
119 <T as AggregateFunction<U>>::default_value(user_data, context)
120 }
121
122 fn step(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()> {
123 <T as AggregateFunction<U>>::step(self, context, args)
124 }
125
126 fn value(&self, context: &Context) -> Result<()> {
127 <T as AggregateFunction<U>>::value(self, context)
128 }
129}
130
/// Options used when registering an application-defined function.
///
/// Start from [FunctionOptions::default] and chain the `set_*` methods to customize.
#[derive(Clone, Debug)]
pub struct FunctionOptions {
    /// Number of parameters accepted by the function, or -1 for any number.
    n_args: i32,
    /// Flag bits handed to SQLite when the function is registered.
    flags: i32,
}
136
137impl Default for FunctionOptions {
138 fn default() -> Self {
139 FunctionOptions::default()
140 }
141}
142
143impl FunctionOptions {
144 pub const fn default() -> Self {
145 FunctionOptions {
146 n_args: -1,
147 flags: 0,
148 }
149 }
150
151 /// Set the number of parameters accepted by this function. Multiple functions may be
152 /// provided under the same name with different n_args values; the implementation will
153 /// be chosen by SQLite based on the number of parameters at the call site. The value
154 /// may also be -1, which means that the function accepts any number of parameters.
155 /// Functions which take a specific number of parameters take precedence over functions
156 /// which take any number.
157 ///
158 /// # Panics
159 ///
160 /// This function panics if n_args is outside the range -1..128. This limitation is
161 /// imposed by SQLite.
162 pub const fn set_n_args(mut self, n_args: i32) -> Self {
163 assert!(n_args >= -1 && n_args < 128, "n_args invalid");
164 self.n_args = n_args;
165 self
166 }
167
168 /// Enable or disable the deterministic flag. This flag indicates that the function is
169 /// pure. It must have no side effects and the value must be determined solely its the
170 /// parameters.
171 ///
172 /// The SQLite query planner is able to perform additional optimizations on
173 /// deterministic functions, so use of this flag is recommended where possible.
174 pub const fn set_deterministic(mut self, val: bool) -> Self {
175 if val {
176 self.flags |= ffi::SQLITE_DETERMINISTIC;
177 } else {
178 self.flags &= !ffi::SQLITE_DETERMINISTIC;
179 }
180 self
181 }
182
183 /// Set the level of risk for this function. See the [RiskLevel] enum for details about
184 /// what the individual options mean.
185 ///
186 /// Requires SQLite 3.31.0. On earlier versions of SQLite, this function is a harmless no-op.
187 pub const fn set_risk_level(
188 #[cfg_attr(not(modern_sqlite), allow(unused_mut))] mut self,
189 level: RiskLevel,
190 ) -> Self {
191 let _ = level;
192 #[cfg(modern_sqlite)]
193 {
194 self.flags |= match level {
195 RiskLevel::Innocuous => ffi::SQLITE_INNOCUOUS,
196 RiskLevel::DirectOnly => ffi::SQLITE_DIRECTONLY,
197 };
198 self.flags &= match level {
199 RiskLevel::Innocuous => !ffi::SQLITE_DIRECTONLY,
200 RiskLevel::DirectOnly => !ffi::SQLITE_INNOCUOUS,
201 };
202 }
203 self
204 }
205}
206
207impl Connection {
208 /// Create a stub function that always fails.
209 ///
210 /// This API makes sure a global version of a function with a particular name and
211 /// number of parameters exists. If no such function exists before this API is called,
212 /// a new function is created. The implementation of the new function always causes an
213 /// exception to be thrown. So the new function is not good for anything by itself. Its
214 /// only purpose is to be a placeholder function that can be overloaded by a virtual
215 /// table.
216 ///
217 /// For more information, see [vtab::FindFunctionVTab](super::vtab::FindFunctionVTab).
218 pub fn create_overloaded_function(&self, name: &str, opts: &FunctionOptions) -> Result<()> {
219 let guard = self.lock();
220 let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
221 unsafe {
222 Error::from_sqlite_desc(
223 ffi::sqlite3_overload_function(self.as_mut_ptr(), name.as_ptr() as _, opts.n_args),
224 guard,
225 )
226 }
227 }
228
229 /// Create a new scalar function. The function will be invoked with a [Context] and an array of
230 /// [ValueRef] objects. The function is required to set its output using [Context::set_result].
231 /// If no result is set, SQL NULL is returned. If the function returns an Err value, the SQL
232 /// statement will fail, even if a result had been set.
233 ///
234 /// The passed function can be a closure, however the lifetime of the closure must be
235 /// `'static` due to limitations in the Rust borrow checker. The
236 /// [Self::create_scalar_function_object] function is an alternative that allows using an
237 /// alternative lifetime.
238 ///
239 /// # Compatibility
240 ///
241 /// On versions of SQLite earlier than 3.7.3, this function will leak the function and
242 /// all bound variables. This is because these versions of SQLite did not provide the
243 /// ability to specify a destructor function.
244 pub fn create_scalar_function<F>(
245 &self,
246 name: &str,
247 opts: &FunctionOptions,
248 func: F,
249 ) -> Result<()>
250 where
251 F: Fn(&Context, &mut [&mut ValueRef]) -> Result<()> + 'static,
252 {
253 self.create_scalar_function_object(name, &opts, ScalarClosure(func))
254 }
255
256 /// Create a new scalar function using a struct. This function is identical to
257 /// [Self::create_scalar_function], but uses a trait object instead of a closure. This enables
258 /// creating scalar functions that maintain references with a lifetime smaller than `'static`.
259 pub fn create_scalar_function_object<'db, F>(
260 &'db self,
261 name: &str,
262 opts: &FunctionOptions,
263 func: F,
264 ) -> Result<()>
265 where
266 F: ScalarFunction<'db>,
267 {
268 let guard = self.lock();
269 let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
270 let func = Box::new(func);
271 unsafe {
272 Error::from_sqlite_desc(
273 sqlite3_match_version! {
274 3_007_003 => ffi::sqlite3_create_function_v2(
275 self.as_mut_ptr(),
276 name.as_ptr() as _,
277 opts.n_args,
278 opts.flags,
279 Box::into_raw(func) as _,
280 Some(stubs::call_scalar::<F>),
281 None,
282 None,
283 Some(ffi::drop_boxed::<F>),
284 ),
285 _ => ffi::sqlite3_create_function(
286 self.as_mut_ptr(),
287 name.as_ptr() as _,
288 opts.n_args,
289 opts.flags,
290 Box::into_raw(func) as _,
291 Some(stubs::call_scalar::<F>),
292 None,
293 None,
294 ),
295 },
296 guard,
297 )
298 }
299 }
300
301 /// Create a new aggregate function which cannot be used as a window function.
302 ///
303 /// In general, you should use
304 /// [create_aggregate_function](Connection::create_aggregate_function) instead, which
305 /// provides all of the same features as legacy aggregate functions but also support
306 /// WINDOW.
307 ///
308 /// # Compatibility
309 ///
310 /// On versions of SQLite earlier than 3.7.3, this function will leak the user data.
311 /// This is because these versions of SQLite did not provide the ability to specify a
312 /// destructor function.
313 pub fn create_legacy_aggregate_function<U, F: LegacyAggregateFunction<U>>(
314 &self,
315 name: &str,
316 opts: &FunctionOptions,
317 user_data: U,
318 ) -> Result<()> {
319 let guard = self.lock();
320 let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
321 let user_data = Box::new(user_data);
322 unsafe {
323 Error::from_sqlite_desc(
324 sqlite3_match_version! {
325 3_007_003 => ffi::sqlite3_create_function_v2(
326 self.as_mut_ptr(),
327 name.as_ptr() as _,
328 opts.n_args,
329 opts.flags,
330 Box::into_raw(user_data) as _,
331 None,
332 Some(stubs::aggregate_step::<U, F>),
333 Some(stubs::aggregate_final::<U, F>),
334 Some(ffi::drop_boxed::<U>),
335 ),
336 _ => ffi::sqlite3_create_function(
337 self.as_mut_ptr(),
338 name.as_ptr() as _,
339 opts.n_args,
340 opts.flags,
341 Box::into_raw(user_data) as _,
342 None,
343 Some(stubs::aggregate_step::<U, F>),
344 Some(stubs::aggregate_final::<U, F>),
345 ),
346 },
347 guard,
348 )
349 }
350 }
351
352 /// Create a new aggregate function.
353 ///
354 /// # Compatibility
355 ///
356 /// Window functions require SQLite 3.25.0. On earlier versions of SQLite, this
357 /// function will automatically fall back to
358 /// [create_legacy_aggregate_function](Connection::create_legacy_aggregate_function).
359 pub fn create_aggregate_function<U, F: AggregateFunction<U>>(
360 &self,
361 name: &str,
362 opts: &FunctionOptions,
363 user_data: U,
364 ) -> Result<()> {
365 sqlite3_match_version! {
366 3_025_000 => {
367 let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
368 let user_data = Box::new(user_data);
369 let guard = self.lock();
370 unsafe {
371 Error::from_sqlite_desc(ffi::sqlite3_create_window_function(
372 self.as_mut_ptr(),
373 name.as_ptr() as _,
374 opts.n_args,
375 opts.flags,
376 Box::into_raw(user_data) as _,
377 Some(stubs::aggregate_step::<U, F>),
378 Some(stubs::aggregate_final::<U, F>),
379 Some(stubs::aggregate_value::<U, F>),
380 Some(stubs::aggregate_inverse::<U, F>),
381 Some(ffi::drop_boxed::<U>),
382 ), guard)
383 }
384 },
385 _ => self.create_legacy_aggregate_function::<U, F>(name, opts, user_data),
386 }
387 }
388
389 /// Remove an application-defined scalar or aggregate function. The name and n_args
390 /// parameters must match the values used when the function was created.
391 pub fn remove_function(&self, name: &str, n_args: i32) -> Result<()> {
392 let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
393 let guard = self.lock();
394 unsafe {
395 Error::from_sqlite_desc(
396 ffi::sqlite3_create_function(
397 self.as_mut_ptr(),
398 name.as_ptr() as _,
399 n_args,
400 0,
401 null_mut(),
402 None,
403 None,
404 None,
405 ),
406 guard,
407 )
408 }
409 }
410
411 /// Register a new collating sequence.
412 pub fn create_collation<F: Fn(&str, &str) -> Ordering>(
413 &self,
414 name: &str,
415 func: F,
416 ) -> Result<()> {
417 let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
418 let func = Box::into_raw(Box::new(func));
419 let guard = self.lock();
420 unsafe {
421 let rc = ffi::sqlite3_create_collation_v2(
422 self.as_mut_ptr(),
423 name.as_ptr() as _,
424 ffi::SQLITE_UTF8,
425 func as _,
426 Some(stubs::compare::<F>),
427 Some(ffi::drop_boxed::<F>),
428 );
429 if rc != ffi::SQLITE_OK {
430 // The xDestroy callback is not called if the
431 // sqlite3_create_collation_v2() function fails.
432 drop(Box::from_raw(func));
433 }
434 Error::from_sqlite_desc(rc, guard)
435 }
436 }
437
438 /// Register a callback for when SQLite needs a collation sequence. The function will
439 /// be invoked when a collation sequence is needed, and
440 /// [create_collation](Connection::create_collation) can be used to provide the needed
441 /// sequence.
442 ///
443 /// Note: the provided function and any captured variables will be leaked. SQLite does
444 /// not provide any facilities for cleaning up this data.
445 pub fn set_collation_needed_func<F: Fn(&str)>(&self, func: F) -> Result<()> {
446 let func = Box::new(func);
447 let guard = self.lock();
448 unsafe {
449 Error::from_sqlite_desc(
450 ffi::sqlite3_collation_needed(
451 self.as_mut_ptr(),
452 Box::into_raw(func) as _,
453 Some(stubs::collation_needed::<F>),
454 ),
455 guard,
456 )
457 }
458 }
459}