1use std::any::Any;
56use std::ffi::{c_int, c_uint, c_void};
57use std::marker::PhantomData;
58use std::ops::Deref;
59use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
60use std::ptr;
61use std::slice;
62use std::sync::Arc;
63
64use crate::ffi;
65use crate::ffi::sqlite3_context;
66use crate::ffi::sqlite3_value;
67
68use crate::context::set_result;
69use crate::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef};
70use crate::util::free_boxed_value;
71use crate::{str_to_cstring, Connection, Error, InnerConnection, Name, Result};
72
73unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
74 if let Error::SqliteFailure(ref err, ref s) = *err {
75 ffi::sqlite3_result_error_code(ctx, err.extended_code);
76 if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
77 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
78 }
79 } else {
80 ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION);
81 if let Ok(cstr) = str_to_cstring(&err.to_string()) {
82 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
83 }
84 }
85}
86
/// Per-invocation view handed to user-defined SQL functions: the raw SQLite function
/// context plus the argument values of the current call.
pub struct Context<'a> {
    /// Raw SQLite function context for the current invocation.
    ctx: *mut sqlite3_context,
    /// Argument values of the current invocation. The finalize/value trampolines in this
    /// file build a `Context` with an empty slice here.
    args: &'a [*mut sqlite3_value],
}
93
94impl Context<'_> {
95 #[inline]
97 #[must_use]
98 pub fn len(&self) -> usize {
99 self.args.len()
100 }
101
102 #[inline]
104 #[must_use]
105 pub fn is_empty(&self) -> bool {
106 self.args.is_empty()
107 }
108
109 pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
119 let arg = self.args[idx];
120 let value = unsafe { ValueRef::from_value(arg) };
121 FromSql::column_result(value).map_err(|err| match err {
122 FromSqlError::InvalidType => {
123 Error::InvalidFunctionParameterType(idx, value.data_type())
124 }
125 FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
126 FromSqlError::Utf8Error(err) => Error::Utf8Error(idx, err),
127 FromSqlError::Other(err) => {
128 Error::FromSqlConversionFailure(idx, value.data_type(), err)
129 }
130 FromSqlError::InvalidBlobSize { .. } => {
131 Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
132 }
133 })
134 }
135
136 #[cfg(feature = "pointer")]
140 pub unsafe fn get_pointer<T: 'static>(
141 &self,
142 idx: usize,
143 ptr_type: &'static std::ffi::CStr,
144 ) -> Option<&T> {
145 let arg = self.args[idx];
146 debug_assert_eq!(unsafe { ffi::sqlite3_value_type(arg) }, ffi::SQLITE_NULL);
147 unsafe {
148 ffi::sqlite3_value_pointer(arg, ptr_type.as_ptr())
149 .cast::<T>()
150 .as_ref()
151 }
152 }
153
154 #[inline]
161 #[must_use]
162 pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
163 let arg = self.args[idx];
164 unsafe { ValueRef::from_value(arg) }
165 }
166
167 #[inline]
170 #[must_use]
171 pub fn get_arg(&self, idx: usize) -> SqlFnArg {
172 assert!(idx < self.len());
173 SqlFnArg { idx }
174 }
175
176 pub fn get_subtype(&self, idx: usize) -> c_uint {
183 let arg = self.args[idx];
184 unsafe { ffi::sqlite3_value_subtype(arg) }
185 }
186
187 pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
200 where
201 T: Send + Sync + 'static,
202 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
203 F: FnOnce(ValueRef<'_>) -> Result<T, E>,
204 {
205 if let Some(v) = self.get_aux(arg)? {
206 Ok(v)
207 } else {
208 let vr = self.get_raw(arg as usize);
209 self.set_aux(
210 arg,
211 func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
212 )
213 }
214 }
215
216 pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
225 assert!(arg < self.len() as i32);
226 let orig: Arc<T> = Arc::new(value);
227 let inner: AuxInner = orig.clone();
228 let outer = Box::new(inner);
229 let raw: *mut AuxInner = Box::into_raw(outer);
230 unsafe {
231 ffi::sqlite3_set_auxdata(
232 self.ctx,
233 arg,
234 raw.cast(),
235 Some(free_boxed_value::<AuxInner>),
236 );
237 };
238 Ok(orig)
239 }
240
241 pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
251 assert!(arg < self.len() as i32);
252 let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
253 if p.is_null() {
254 Ok(None)
255 } else {
256 let v: AuxInner = AuxInner::clone(unsafe { &*p });
257 v.downcast::<T>()
258 .map(Some)
259 .map_err(|_| Error::GetAuxWrongType)
260 }
261 }
262
263 pub unsafe fn get_connection(&self) -> Result<ConnectionRef<'_>> {
270 let handle = ffi::sqlite3_context_db_handle(self.ctx);
271 Ok(ConnectionRef {
272 conn: Connection::from_handle(handle)?,
273 phantom: PhantomData,
274 })
275 }
276}
277
/// A [`Connection`] obtained from inside a user function via `Context::get_connection`.
///
/// The `'ctx` lifetime ties it to the borrow of the [`Context`] it came from, so it cannot
/// escape the function invocation.
pub struct ConnectionRef<'ctx> {
    // Connection rebuilt from the context's database handle with `Connection::from_handle`.
    // NOTE(review): presumably it does not close the shared handle on drop; confirm in
    // `Connection::from_handle`.
    conn: Connection,
    // Zero-sized marker carrying the `'ctx` borrow of the originating `Context`.
    phantom: PhantomData<&'ctx Context<'ctx>>,
}

/// Gives read access to the wrapped [`Connection`].
impl Deref for ConnectionRef<'_> {
    type Target = Connection;

    #[inline]
    fn deref(&self) -> &Connection {
        &self.conn
    }
}
294
/// Type-erased auxiliary value. `Context::set_aux` boxes one and hands it to SQLite;
/// `Context::get_aux` downcasts it back to the requested type.
type AuxInner = Arc<dyn Any + Send + Sync + 'static>;

/// Optional SQLite subtype attached to a function result. `None` means no subtype is set.
pub type SubType = Option<c_uint>;
299
300pub trait SqlFnOutput {
302 fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)>;
304}
305
306impl<T: ToSql> SqlFnOutput for T {
307 #[inline]
308 fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
309 ToSql::to_sql(self).map(|o| (o, None))
310 }
311}
312
313impl<T: ToSql> SqlFnOutput for (T, SubType) {
314 fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
315 ToSql::to_sql(&self.0).map(|o| (o, self.1))
316 }
317}
318
319pub struct SqlFnArg {
321 idx: usize,
322}
323impl ToSql for SqlFnArg {
324 fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
325 Ok(ToSqlOutput::Arg(self.idx))
326 }
327}
328
329unsafe fn sql_result<T: SqlFnOutput>(
330 ctx: *mut sqlite3_context,
331 args: &[*mut sqlite3_value],
332 r: Result<T>,
333) {
334 let t = r.as_ref().map(SqlFnOutput::to_sql);
335
336 match t {
337 Ok(Ok((ref value, sub_type))) => {
338 set_result(ctx, args, value);
339 if let Some(sub_type) = sub_type {
340 ffi::sqlite3_result_subtype(ctx, sub_type);
341 }
342 }
343 Ok(Err(err)) => report_error(ctx, &err),
344 Err(err) => report_error(ctx, err),
345 }
346}
347
/// User-defined aggregate function with accumulator type `A` and result type `T`.
///
/// Registered through `Connection::create_aggregate_function`. The `call_boxed_step`
/// trampoline in this module calls `init` lazily on the first step of a group and then
/// `step` for every row. `call_boxed_final` calls `finalize` once per group.
pub trait Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: SqlFnOutput,
{
    /// Creates the accumulator for a group. Called on that group's first step.
    fn init(&self, ctx: &mut Context<'_>) -> Result<A>;

    /// Folds the current row's arguments (available through `ctx`) into `acc`.
    fn step(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;

    /// Produces the aggregate result.
    ///
    /// `acc` is `None` when no accumulator was ever created, e.g. when the group had no
    /// rows. `ctx` carries no arguments here.
    fn finalize(&self, ctx: &mut Context<'_>, acc: Option<A>) -> Result<T>;
}

/// Aggregate that can also be used as a window function.
///
/// Registered through `Connection::create_window_function`.
#[cfg(feature = "window")]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: SqlFnOutput,
{
    /// Returns the current value of the window without consuming the accumulator.
    /// `acc` is `None` if no accumulator has been created yet.
    fn value(&self, acc: Option<&mut A>) -> Result<T>;

    /// Removes the row whose arguments are in `ctx` from `acc`. This is the inverse of `step`.
    fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
}
395
bitflags::bitflags! {
    /// Flags passed as the `flags` argument when registering a user-defined function
    /// (text encoding plus SQLite function-behaviour flags).
    #[derive(Clone, Copy, Debug)]
    #[repr(C)]
    pub struct FunctionFlags: c_int {
        /// Text encoding: UTF-8.
        const SQLITE_UTF8 = ffi::SQLITE_UTF8;
        /// Text encoding: UTF-16 little-endian.
        const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE;
        /// Text encoding: UTF-16 big-endian.
        const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE;
        /// Text encoding: UTF-16 in native byte order.
        const SQLITE_UTF16 = ffi::SQLITE_UTF16;
        /// SQLite's `SQLITE_DETERMINISTIC` flag (same inputs always give the same output).
        const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC;
        /// SQLite's `SQLITE_DIRECTONLY` flag (0x80000).
        const SQLITE_DIRECTONLY = 0x0000_0008_0000;
        /// SQLite's `SQLITE_SUBTYPE` flag (0x100000).
        const SQLITE_SUBTYPE = 0x0000_0010_0000;
        /// SQLite's `SQLITE_INNOCUOUS` flag (0x200000).
        const SQLITE_INNOCUOUS = 0x0000_0020_0000;
        /// SQLite's `SQLITE_RESULT_SUBTYPE` flag (0x1000000). Set it on functions that
        /// return a subtype, as `test_setsubtype` does.
        const SQLITE_RESULT_SUBTYPE = 0x0000_0100_0000;
        /// SQLite's `SQLITE_SELFORDER1` flag (0x2000000).
        const SQLITE_SELFORDER1 = 0x0000_0200_0000;
    }
}
425
426impl Default for FunctionFlags {
427 #[inline]
428 fn default() -> Self {
429 Self::SQLITE_UTF8
430 }
431}
432
433impl Connection {
434 #[inline]
472 pub fn create_scalar_function<F, N: Name, T>(
473 &self,
474 fn_name: N,
475 n_arg: c_int,
476 flags: FunctionFlags,
477 x_func: F,
478 ) -> Result<()>
479 where
480 F: Fn(&Context<'_>) -> Result<T> + Send + 'static,
481 T: SqlFnOutput,
482 {
483 self.db
484 .borrow_mut()
485 .create_scalar_function(fn_name, n_arg, flags, x_func)
486 }
487
488 #[inline]
495 pub fn create_aggregate_function<A, D, N: Name, T>(
496 &self,
497 fn_name: N,
498 n_arg: c_int,
499 flags: FunctionFlags,
500 aggr: D,
501 ) -> Result<()>
502 where
503 A: RefUnwindSafe + UnwindSafe,
504 D: Aggregate<A, T> + 'static,
505 T: SqlFnOutput,
506 {
507 self.db
508 .borrow_mut()
509 .create_aggregate_function(fn_name, n_arg, flags, aggr)
510 }
511
512 #[cfg(feature = "window")]
518 #[inline]
519 pub fn create_window_function<A, N: Name, W, T>(
520 &self,
521 fn_name: N,
522 n_arg: c_int,
523 flags: FunctionFlags,
524 aggr: W,
525 ) -> Result<()>
526 where
527 A: RefUnwindSafe + UnwindSafe,
528 W: WindowAggregate<A, T> + 'static,
529 T: SqlFnOutput,
530 {
531 self.db
532 .borrow_mut()
533 .create_window_function(fn_name, n_arg, flags, aggr)
534 }
535
536 #[inline]
547 pub fn remove_function<N: Name>(&self, fn_name: N, n_arg: c_int) -> Result<()> {
548 self.db.borrow_mut().remove_function(fn_name, n_arg)
549 }
550}
551
552impl InnerConnection {
553 fn create_scalar_function<F, N: Name, T>(
575 &mut self,
576 fn_name: N,
577 n_arg: c_int,
578 flags: FunctionFlags,
579 x_func: F,
580 ) -> Result<()>
581 where
582 F: Fn(&Context<'_>) -> Result<T> + Send + 'static,
583 T: SqlFnOutput,
584 {
585 unsafe extern "C" fn call_boxed_closure<F, T>(
586 ctx: *mut sqlite3_context,
587 argc: c_int,
588 argv: *mut *mut sqlite3_value,
589 ) where
590 F: Fn(&Context<'_>) -> Result<T>,
591 T: SqlFnOutput,
592 {
593 let args = slice::from_raw_parts(argv, argc as usize);
594 let r = catch_unwind(|| {
595 let boxed_f: *const F = ffi::sqlite3_user_data(ctx).cast::<F>();
596 assert!(!boxed_f.is_null(), "Internal error - null function pointer");
597 let ctx = Context { ctx, args };
598 (*boxed_f)(&ctx)
599 });
600 let t = match r {
601 Err(_) => {
602 report_error(ctx, &Error::UnwindingPanic);
603 return;
604 }
605 Ok(r) => r,
606 };
607 sql_result(ctx, args, t);
608 }
609
610 let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
611 let c_name = fn_name.as_cstr()?;
612 let r = unsafe {
613 ffi::sqlite3_create_function_v2(
614 self.db(),
615 c_name.as_ptr(),
616 n_arg,
617 flags.bits(),
618 boxed_f.cast::<c_void>(),
619 Some(call_boxed_closure::<F, T>),
620 None,
621 None,
622 Some(free_boxed_value::<F>),
623 )
624 };
625 self.decode_result(r)
626 }
627
628 fn create_aggregate_function<A, D, N: Name, T>(
629 &mut self,
630 fn_name: N,
631 n_arg: c_int,
632 flags: FunctionFlags,
633 aggr: D,
634 ) -> Result<()>
635 where
636 A: RefUnwindSafe + UnwindSafe,
637 D: Aggregate<A, T> + 'static,
638 T: SqlFnOutput,
639 {
640 let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
641 let c_name = fn_name.as_cstr()?;
642 let r = unsafe {
643 ffi::sqlite3_create_function_v2(
644 self.db(),
645 c_name.as_ptr(),
646 n_arg,
647 flags.bits(),
648 boxed_aggr.cast::<c_void>(),
649 None,
650 Some(call_boxed_step::<A, D, T>),
651 Some(call_boxed_final::<A, D, T>),
652 Some(free_boxed_value::<D>),
653 )
654 };
655 self.decode_result(r)
656 }
657
658 #[cfg(feature = "window")]
659 fn create_window_function<A, N: Name, W, T>(
660 &mut self,
661 fn_name: N,
662 n_arg: c_int,
663 flags: FunctionFlags,
664 aggr: W,
665 ) -> Result<()>
666 where
667 A: RefUnwindSafe + UnwindSafe,
668 W: WindowAggregate<A, T> + 'static,
669 T: SqlFnOutput,
670 {
671 let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
672 let c_name = fn_name.as_cstr()?;
673 let r = unsafe {
674 ffi::sqlite3_create_window_function(
675 self.db(),
676 c_name.as_ptr(),
677 n_arg,
678 flags.bits(),
679 boxed_aggr.cast::<c_void>(),
680 Some(call_boxed_step::<A, W, T>),
681 Some(call_boxed_final::<A, W, T>),
682 Some(call_boxed_value::<A, W, T>),
683 Some(call_boxed_inverse::<A, W, T>),
684 Some(free_boxed_value::<W>),
685 )
686 };
687 self.decode_result(r)
688 }
689
690 fn remove_function<N: Name>(&mut self, fn_name: N, n_arg: c_int) -> Result<()> {
691 let c_name = fn_name.as_cstr()?;
692 let r = unsafe {
693 ffi::sqlite3_create_function_v2(
694 self.db(),
695 c_name.as_ptr(),
696 n_arg,
697 ffi::SQLITE_UTF8,
698 ptr::null_mut(),
699 None,
700 None,
701 None,
702 None,
703 )
704 };
705 self.decode_result(r)
706 }
707}
708
709unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
710 let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
711 if pac.is_null() {
712 return None;
713 }
714 Some(pac)
715}
716
717unsafe extern "C" fn call_boxed_step<A, D, T>(
718 ctx: *mut sqlite3_context,
719 argc: c_int,
720 argv: *mut *mut sqlite3_value,
721) where
722 A: RefUnwindSafe + UnwindSafe,
723 D: Aggregate<A, T>,
724 T: SqlFnOutput,
725{
726 let Some(pac) = aggregate_context(ctx, size_of::<*mut A>()) else {
727 ffi::sqlite3_result_error_nomem(ctx);
728 return;
729 };
730
731 let r = catch_unwind(|| {
732 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
733 assert!(
734 !boxed_aggr.is_null(),
735 "Internal error - null aggregate pointer"
736 );
737 let mut ctx = Context {
738 ctx,
739 args: slice::from_raw_parts(argv, argc as usize),
740 };
741
742 #[expect(clippy::unnecessary_cast)]
743 if (*pac as *mut A).is_null() {
744 *pac = Box::into_raw(Box::new((*boxed_aggr).init(&mut ctx)?));
745 }
746
747 (*boxed_aggr).step(&mut ctx, &mut **pac)
748 });
749 let r = match r {
750 Err(_) => {
751 report_error(ctx, &Error::UnwindingPanic);
752 return;
753 }
754 Ok(r) => r,
755 };
756 match r {
757 Ok(_) => {}
758 Err(err) => report_error(ctx, &err),
759 }
760}
761
762#[cfg(feature = "window")]
763unsafe extern "C" fn call_boxed_inverse<A, W, T>(
764 ctx: *mut sqlite3_context,
765 argc: c_int,
766 argv: *mut *mut sqlite3_value,
767) where
768 A: RefUnwindSafe + UnwindSafe,
769 W: WindowAggregate<A, T>,
770 T: SqlFnOutput,
771{
772 let Some(pac) = aggregate_context(ctx, size_of::<*mut A>()) else {
773 ffi::sqlite3_result_error_nomem(ctx);
774 return;
775 };
776
777 let r = catch_unwind(|| {
778 let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
779 assert!(
780 !boxed_aggr.is_null(),
781 "Internal error - null aggregate pointer"
782 );
783 let mut ctx = Context {
784 ctx,
785 args: slice::from_raw_parts(argv, argc as usize),
786 };
787 (*boxed_aggr).inverse(&mut ctx, &mut **pac)
788 });
789 let r = match r {
790 Err(_) => {
791 report_error(ctx, &Error::UnwindingPanic);
792 return;
793 }
794 Ok(r) => r,
795 };
796 match r {
797 Ok(_) => {}
798 Err(err) => report_error(ctx, &err),
799 }
800}
801
802unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
803where
804 A: RefUnwindSafe + UnwindSafe,
805 D: Aggregate<A, T>,
806 T: SqlFnOutput,
807{
808 let a: Option<A> = match aggregate_context(ctx, 0) {
811 Some(pac) =>
812 {
813 #[expect(clippy::unnecessary_cast)]
814 if (*pac as *mut A).is_null() {
815 None
816 } else {
817 let a = Box::from_raw(*pac);
818 Some(*a)
819 }
820 }
821 None => None,
822 };
823
824 let r = catch_unwind(|| {
825 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
826 assert!(
827 !boxed_aggr.is_null(),
828 "Internal error - null aggregate pointer"
829 );
830 let mut ctx = Context { ctx, args: &mut [] };
831 (*boxed_aggr).finalize(&mut ctx, a)
832 });
833 let t = match r {
834 Err(_) => {
835 report_error(ctx, &Error::UnwindingPanic);
836 return;
837 }
838 Ok(r) => r,
839 };
840 sql_result(ctx, &[], t);
841}
842
843#[cfg(feature = "window")]
844unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
845where
846 A: RefUnwindSafe + UnwindSafe,
847 W: WindowAggregate<A, T>,
848 T: SqlFnOutput,
849{
850 let pac = aggregate_context(ctx, 0).filter(|&pac| {
853 #[expect(clippy::unnecessary_cast)]
854 !(*pac as *mut A).is_null()
855 });
856
857 let r = catch_unwind(|| {
858 let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
859 assert!(
860 !boxed_aggr.is_null(),
861 "Internal error - null aggregate pointer"
862 );
863 (*boxed_aggr).value(pac.map(|pac| &mut **pac))
864 });
865 let t = match r {
866 Err(_) => {
867 report_error(ctx, &Error::UnwindingPanic);
868 return;
869 }
870 Ok(r) => r,
871 };
872 sql_result(ctx, &[], t);
873}
874
// Tests for scalar, aggregate and window functions, auxiliary data, subtypes and pointer
// values. Compiled out under Miri.
#[cfg(all(test, not(miri)))]
mod test {
    // On wasm32-unknown-unknown the tests run through wasm-bindgen-test.
    #[cfg(all(target_family = "wasm", target_os = "unknown"))]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    #[cfg(feature = "window")]
    use crate::functions::WindowAggregate;
    use crate::functions::{Aggregate, Context, FunctionFlags, SqlFnArg, SubType};
    use crate::{Connection, Error, Result};
    use regex::Regex;
    use std::ffi::c_double;

    // Scalar function returning half of its single numeric argument. It also checks that
    // `get_connection` works from inside a function call.
    fn half(ctx: &Context<'_>) -> Result<c_double> {
        assert!(!ctx.is_empty());
        assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
        assert!(unsafe {
            ctx.get_connection()
                .as_ref()
                .map(::std::ops::Deref::deref)
                .is_ok()
        });
        let value = ctx.get::<c_double>(0)?;
        Ok(value / 2f64)
    }

    // Registers `half` and checks half(6) == 3.
    #[test]
    fn test_function_half() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            c"half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )?;
        let result: f64 = db.one_column("SELECT half(6)", [])?;

        assert!((3f64 - result).abs() < f64::EPSILON);
        Ok(())
    }

    // After `remove_function`, calling the function must fail.
    #[test]
    fn test_remove_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            c"half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )?;
        assert!((3f64 - db.one_column::<f64, _>("SELECT half(6)", [])?).abs() < f64::EPSILON);

        db.remove_function(c"half", 1)?;
        db.one_column::<f64, _>("SELECT half(6)", []).unwrap_err();
        Ok(())
    }

    // regexp(pattern, text). The compiled `Regex` is stored as auxiliary data on argument 0,
    // so SQLite may reuse it across rows when the pattern is a constant.
    fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<bool> {
        assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
        let regexp: std::sync::Arc<Regex> = ctx
            .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
                Ok(Regex::new(vr.as_str()?)?)
            })?;

        let is_match = {
            let text = ctx
                .get_raw(1)
                .as_str()
                .map_err(|e| Error::UserFunctionError(e.into()))?;

            regexp.is_match(text)
        };

        Ok(is_match)
    }

    // Uses the aux-caching regexp on a single value and across table rows.
    #[test]
    fn test_function_regexp_with_auxiliary() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            "BEGIN;
 CREATE TABLE foo (x string);
 INSERT INTO foo VALUES ('lisa');
 INSERT INTO foo VALUES ('lXsi');
 INSERT INTO foo VALUES ('lisX');
 END;",
        )?;
        db.create_scalar_function(
            c"regexp",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            regexp_with_auxiliary,
        )?;

        assert!(db.one_column::<bool, _>("SELECT regexp('l.s[aeiouy]', 'lisa')", [])?);

        assert_eq!(
            2,
            db.one_column::<i64, _>(
                "SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
                [],
            )?
        );
        Ok(())
    }

    // Variadic function (n_arg = -1) concatenating all of its text arguments.
    #[test]
    fn test_varargs_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            c"my_concat",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            |ctx| {
                let mut ret = String::new();

                for idx in 0..ctx.len() {
                    let s = ctx.get::<String>(idx)?;
                    ret.push_str(&s);
                }

                Ok(ret)
            },
        )?;

        for &(expected, query) in &[
            ("", "SELECT my_concat()"),
            ("onetwo", "SELECT my_concat('one', 'two')"),
            ("abc", "SELECT my_concat('a', 'b', 'c')"),
        ] {
            assert_eq!(expected, db.one_column::<String, _>(query, [])?);
        }
        Ok(())
    }

    // One row stores an i64 as aux data. Another row expects reading it back as the wrong
    // type to give GetAuxWrongType and as i64 to succeed. NOTE(review): this relies on SQLite
    // keeping aux data for the constant argument 0 between rows.
    #[test]
    fn test_get_aux_type_checking() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(c"example", 2, FunctionFlags::default(), |ctx| {
            if !ctx.get::<bool>(1)? {
                ctx.set_aux::<i64>(0, 100)?;
            } else {
                assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
                assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100);
            }
            Ok(true)
        })?;

        let res: bool = db.query_row(
            "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)",
            [],
            |r| r.get(0),
        )?;
        assert!(res);
        Ok(())
    }

    // Test aggregates: `Sum` adds argument 0 and returns NULL when no accumulator was created.
    // `Count` counts rows and returns 0 in that case.
    struct Sum;
    struct Count;

    impl Aggregate<i64, Option<i64>> for Sum {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += ctx.get::<i64>(0)?;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> {
            Ok(sum)
        }
    }

    impl Aggregate<i64, i64> for Count {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += 1;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> {
            Ok(sum.unwrap_or(0))
        }
    }

    // `my_sum` over no rows, over one column, and two independent instances in one query.
    #[test]
    fn test_sum() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function(
            c"my_sum",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )?;

        // Empty input: finalize receives None, so the result is NULL.
        let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        assert!(db.one_column::<Option<i64>, _>(no_result, [])?.is_none());

        let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        assert_eq!(4, db.one_column::<i64, _>(single_sum, [])?);

        let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
 2, 1)";
        let result: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?;
        assert_eq!((4, 2), result);
        Ok(())
    }

    // Variadic `my_count` over no rows and over two rows.
    #[test]
    fn test_count() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function(
            c"my_count",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Count,
        )?;

        // Empty input: finalize receives None, which `Count` maps to 0.
        let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        assert_eq!(db.one_column::<i64, _>(no_result, [])?, 0);

        let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        assert_eq!(2, db.one_column::<i64, _>(single_sum, [])?);
        Ok(())
    }

    // Makes `Sum` usable as a window function: `inverse` subtracts the departing row.
    #[cfg(feature = "window")]
    impl WindowAggregate<i64, Option<i64>> for Sum {
        fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum -= ctx.get::<i64>(0)?;
            Ok(())
        }

        fn value(&self, sum: Option<&mut i64>) -> Result<Option<i64>> {
            Ok(sum.copied())
        }
    }

    // Sliding 3-row window sum over y = 4, 5, 3, 8, 1 (ordered by x).
    #[test]
    #[cfg(feature = "window")]
    fn test_window() -> Result<()> {
        use fallible_iterator::FallibleIterator as _;

        let db = Connection::open_in_memory()?;
        db.create_window_function(
            c"sumint",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )?;
        db.execute_batch(
            "CREATE TABLE t3(x, y);
 INSERT INTO t3 VALUES('a', 4),
 ('b', 5),
 ('c', 3),
 ('d', 8),
 ('e', 1);",
        )?;

        let mut stmt = db.prepare(
            "SELECT x, sumint(y) OVER (
 ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
 ) AS sum_y
 FROM t3 ORDER BY x;",
        )?;

        let results: Vec<(String, i64)> = stmt
            .query([])?
            .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
            .collect()?;
        let expected = vec![
            ("a".to_owned(), 9),
            ("b".to_owned(), 12),
            ("c".to_owned(), 16),
            ("d".to_owned(), 12),
            ("e".to_owned(), 9),
        ];
        assert_eq!(expected, results);
        Ok(())
    }

    // A subtype set by one function (via a (SqlFnArg, SubType) result) is visible to another.
    #[test]
    fn test_sub_type() -> Result<()> {
        fn test_getsubtype(ctx: &Context<'_>) -> Result<i32> {
            Ok(ctx.get_subtype(0) as i32)
        }
        fn test_setsubtype(ctx: &Context<'_>) -> Result<(SqlFnArg, SubType)> {
            use std::ffi::c_uint;
            let value = ctx.get_arg(0);
            let sub_type = ctx.get::<c_uint>(1)?;
            Ok((value, Some(sub_type)))
        }
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            c"test_getsubtype",
            1,
            FunctionFlags::SQLITE_UTF8,
            test_getsubtype,
        )?;
        db.create_scalar_function(
            c"test_setsubtype",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_RESULT_SUBTYPE,
            test_setsubtype,
        )?;
        let result: i32 = db.one_column("SELECT test_getsubtype('hello');", [])?;
        assert_eq!(0, result);

        let result: i32 =
            db.one_column("SELECT test_getsubtype(test_setsubtype('hello',123));", [])?;
        assert_eq!(123, result);

        Ok(())
    }

    // Blob length via `as_bytes_or_null`, with NULL mapped to 0.
    #[test]
    fn test_blob() -> Result<()> {
        fn test_len(ctx: &Context<'_>) -> Result<u32> {
            let blob = ctx.get_raw(0);
            Ok(blob
                .as_bytes_or_null()?
                .map_or(0, |b| b.len().try_into().unwrap()))
        }
        let db = Connection::open_in_memory()?;
        db.create_scalar_function("test_len", 1, FunctionFlags::SQLITE_DETERMINISTIC, test_len)?;
        assert_eq!(
            6,
            db.one_column::<u32, _>("SELECT test_len(X'53514C697465');", [])?
        );
        assert_eq!(0, db.one_column::<u32, _>("SELECT test_len(X'');", [])?);
        assert_eq!(0, db.one_column::<u32, _>("SELECT test_len(NULL);", [])?);
        Ok(())
    }

    // Rc-backed pointer values round-trip through a function. After the inner scope (and
    // every statement/connection in it) is dropped, the strong count must be back to 1,
    // i.e. nothing leaked.
    #[test]
    #[cfg(feature = "pointer")]
    fn test_rc_pointer() -> Result<()> {
        use crate::types::ToSqlOutput;
        use std::ops::Deref as _;
        use std::rc::Rc;

        const PTR_TYPE: &std::ffi::CStr = c"my_rust_ptr";
        let rc = Rc::new(1);
        {
            let ptr = ToSqlOutput::from_rc(rc.clone(), PTR_TYPE);
            assert_eq!(2, Rc::strong_count(&rc));
            fn myfunc(ctx: &Context<'_>) -> Result<ToSqlOutput<'static>> {
                let x = unsafe { ctx.get_pointer(0, PTR_TYPE) };
                assert_eq!(x, Some(&1));
                Ok(ToSqlOutput::from_rc(Rc::new(*x.unwrap()), PTR_TYPE))
            }
            let db = Connection::open_in_memory()?;
            db.create_scalar_function("myfunc", 1, FunctionFlags::SQLITE_DETERMINISTIC, myfunc)?;
            let mut stmt = db.prepare("SELECT myfunc(?)")?;
            let result = stmt.query_one([ptr], |r| {
                unsafe { r.get_pointer::<_, i32>(0, PTR_TYPE) }.map(|opt| opt.cloned())
            })?;
            assert_eq!(result.unwrap(), *rc.deref());
        }
        assert_eq!(1, Rc::strong_count(&rc));
        Ok(())
    }

    // Box-backed pointer values round-trip through a function.
    #[test]
    #[cfg(feature = "pointer")]
    fn test_box_pointer() -> Result<()> {
        use crate::types::ToSqlOutput;

        const PTR_TYPE: &std::ffi::CStr = c"my_rust_ptr";
        let value = 1;
        let ptr = ToSqlOutput::new_boxed(value, PTR_TYPE);
        fn myfunc(ctx: &Context<'_>) -> Result<ToSqlOutput<'static>> {
            let x = unsafe { ctx.get_pointer(0, PTR_TYPE) };
            assert_eq!(x, Some(&1));
            Ok(ToSqlOutput::new_boxed(*x.unwrap(), PTR_TYPE))
        }
        let db = Connection::open_in_memory()?;
        db.create_scalar_function("myfunc", 1, FunctionFlags::SQLITE_DETERMINISTIC, myfunc)?;
        let mut stmt = db.prepare("SELECT myfunc(?)")?;
        let result = stmt.query_one([ptr], |r| {
            unsafe { r.get_pointer::<_, i32>(0, PTR_TYPE) }.map(|opt| opt.cloned())
        })?;
        assert_eq!(result.unwrap(), value);
        Ok(())
    }
}