1use std::any::Any;
56use std::marker::PhantomData;
57use std::ops::Deref;
58use std::os::raw::{c_int, c_void};
59use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
60use std::ptr;
61use std::slice;
62use std::sync::Arc;
63
64use crate::ffi;
65use crate::ffi::sqlite3_context;
66use crate::ffi::sqlite3_value;
67
68use crate::context::set_result;
69use crate::types::{FromSql, FromSqlError, ToSql, ValueRef};
70
71use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};
72
73unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
74 fn constraint_error_code() -> i32 {
79 ffi::SQLITE_CONSTRAINT_FUNCTION
80 }
81
82 if let Error::SqliteFailure(ref err, ref s) = *err {
83 ffi::sqlite3_result_error_code(ctx, err.extended_code);
84 if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
85 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
86 }
87 } else {
88 ffi::sqlite3_result_error_code(ctx, constraint_error_code());
89 if let Ok(cstr) = str_to_cstring(&err.to_string()) {
90 ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
91 }
92 }
93}
94
/// Destructor callback handed to SQLite. It reclaims the `Box<T>` that was
/// turned into a raw pointer with `Box::into_raw` and drops it.
///
/// # Safety
///
/// `p` must originate from `Box::into_raw` on a `Box<T>` of this exact `T`.
/// It must not be used again afterwards; calling this twice on the same
/// pointer is a double free.
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
    let owned: Box<T> = Box::from_raw(p as *mut T);
    std::mem::drop(owned);
}
98
/// Wrapper around the SQLite function evaluation context for one call of a
/// user-defined function. It gives access to the call's arguments and to
/// per-argument auxiliary data.
pub struct Context<'a> {
    // Raw SQLite context of the current invocation.
    ctx: *mut sqlite3_context,
    // Argument values of this invocation. The trampolines build it from the
    // callback's `argv`/`argc`; it is empty in the finalize callback.
    args: &'a [*mut sqlite3_value],
}
105
106impl Context<'_> {
107 #[inline]
109 #[must_use]
110 pub fn len(&self) -> usize {
111 self.args.len()
112 }
113
114 #[inline]
116 #[must_use]
117 pub fn is_empty(&self) -> bool {
118 self.args.is_empty()
119 }
120
121 pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
131 let arg = self.args[idx];
132 let value = unsafe { ValueRef::from_value(arg) };
133 FromSql::column_result(value).map_err(|err| match err {
134 FromSqlError::InvalidType => {
135 Error::InvalidFunctionParameterType(idx, value.data_type())
136 }
137 FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
138 FromSqlError::Other(err) => {
139 Error::FromSqlConversionFailure(idx, value.data_type(), err)
140 }
141 FromSqlError::InvalidBlobSize { .. } => {
142 Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
143 }
144 })
145 }
146
147 #[inline]
154 #[must_use]
155 pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
156 let arg = self.args[idx];
157 unsafe { ValueRef::from_value(arg) }
158 }
159
160 pub fn get_subtype(&self, idx: usize) -> std::os::raw::c_uint {
167 let arg = self.args[idx];
168 unsafe { ffi::sqlite3_value_subtype(arg) }
169 }
170
171 pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
179 where
180 T: Send + Sync + 'static,
181 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
182 F: FnOnce(ValueRef<'_>) -> Result<T, E>,
183 {
184 if let Some(v) = self.get_aux(arg)? {
185 Ok(v)
186 } else {
187 let vr = self.get_raw(arg as usize);
188 self.set_aux(
189 arg,
190 func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
191 )
192 }
193 }
194
195 pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
199 let orig: Arc<T> = Arc::new(value);
200 let inner: AuxInner = orig.clone();
201 let outer = Box::new(inner);
202 let raw: *mut AuxInner = Box::into_raw(outer);
203 unsafe {
204 ffi::sqlite3_set_auxdata(
205 self.ctx,
206 arg,
207 raw.cast(),
208 Some(free_boxed_value::<AuxInner>),
209 );
210 };
211 Ok(orig)
212 }
213
214 pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
219 let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
220 if p.is_null() {
221 Ok(None)
222 } else {
223 let v: AuxInner = AuxInner::clone(unsafe { &*p });
224 v.downcast::<T>()
225 .map(Some)
226 .map_err(|_| Error::GetAuxWrongType)
227 }
228 }
229
230 pub unsafe fn get_connection(&self) -> Result<ConnectionRef<'_>> {
237 let handle = ffi::sqlite3_context_db_handle(self.ctx);
238 Ok(ConnectionRef {
239 conn: Connection::from_handle(handle)?,
240 phantom: PhantomData,
241 })
242 }
243
244 pub fn set_result_subtype(&self, sub_type: std::os::raw::c_uint) {
246 unsafe { ffi::sqlite3_result_subtype(self.ctx, sub_type) };
247 }
248}
249
/// A view of the database connection that is executing a user function,
/// obtained from [`Context::get_connection`]. It dereferences to
/// [`Connection`].
pub struct ConnectionRef<'ctx> {
    // Built by `Connection::from_handle` around the context's db handle.
    // NOTE(review): presumably a non-owning wrapper, so dropping it does not
    // close the connection. Confirm against `Connection::from_handle`.
    conn: Connection,
    // Ties this value to the lifetime of the `Context` it came from.
    phantom: PhantomData<&'ctx Context<'ctx>>,
}
257
258impl Deref for ConnectionRef<'_> {
259 type Target = Connection;
260
261 #[inline]
262 fn deref(&self) -> &Connection {
263 &self.conn
264 }
265}
266
/// Type-erased, reference-counted payload stored as SQLite auxiliary data.
/// [`Context::set_aux`] stores it and [`Context::get_aux`] recovers it by
/// downcasting.
type AuxInner = Arc<dyn Any + Send + Sync + 'static>;
268
/// Callback interface for a user-defined aggregate SQL function.
///
/// `A` is the per-group accumulator type and `T` is the type of the value
/// produced for each group. Register an implementation with
/// [`Connection::create_aggregate_function`].
pub trait Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: ToSql,
{
    /// Creates the accumulator for a group.
    ///
    /// This is invoked the first time a row is fed to a group, with the same
    /// [`Context`] that is then passed to [`Aggregate::step`] for that row.
    fn init(&self, _: &mut Context<'_>) -> Result<A>;

    /// Folds the current row's arguments (read through the [`Context`]) into
    /// the accumulator.
    fn step(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;

    /// Produces the group's result, taking ownership of the accumulator.
    ///
    /// The accumulator is `None` when none was ever created, e.g. when the
    /// query matched no rows. The [`Context`] passed here has no arguments.
    fn finalize(&self, _: &mut Context<'_>, _: Option<A>) -> Result<T>;
}
300
#[cfg(feature = "window")]
/// Callback interface for a user-defined aggregate window function.
///
/// It extends [`Aggregate`] with the extra callbacks required by
/// `sqlite3_create_window_function`. Register an implementation with
/// [`Connection::create_window_function`].
#[cfg_attr(docsrs, doc(cfg(feature = "window")))]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: ToSql,
{
    /// Returns the current value of the aggregate without consuming the
    /// accumulator. `None` means no accumulator exists yet.
    fn value(&self, _: Option<&A>) -> Result<T>;

    /// Removes the current row's arguments (read through the [`Context`])
    /// from the accumulator. This is the inverse of [`Aggregate::step`].
    fn inverse(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
}
317
bitflags::bitflags! {
    /// Flags passed to SQLite when registering a user-defined function: the
    /// preferred text encoding plus optional behavioural hints.
    ///
    /// Each flag mirrors the SQLite constant of the same name. See SQLite's
    /// "Function Flags" documentation for the exact semantics.
    #[repr(C)]
    pub struct FunctionFlags: ::std::os::raw::c_int {
        /// Prefer arguments as UTF-8 text (`SQLITE_UTF8`).
        const SQLITE_UTF8 = ffi::SQLITE_UTF8;
        /// Prefer arguments as UTF-16 little-endian text (`SQLITE_UTF16LE`).
        const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE;
        /// Prefer arguments as UTF-16 big-endian text (`SQLITE_UTF16BE`).
        const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE;
        /// Prefer arguments as UTF-16 text in native byte order (`SQLITE_UTF16`).
        const SQLITE_UTF16 = ffi::SQLITE_UTF16;
        /// The function returns the same result for the same inputs
        /// (`SQLITE_DETERMINISTIC`).
        const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC;
        /// The function may only be invoked from top-level SQL, not from
        /// views, triggers or schema structures (`SQLITE_DIRECTONLY`).
        // NOTE(review): this and the next two flags are written as literals
        // rather than `ffi` constants, presumably so that bindings lacking
        // them still build. They must equal SQLite's values (0x80000,
        // 0x100000, 0x200000).
        const SQLITE_DIRECTONLY = 0x0000_0008_0000;
        /// The function may inspect its arguments' subtypes (`SQLITE_SUBTYPE`).
        const SQLITE_SUBTYPE = 0x0000_0010_0000;
        /// The function is considered safe to call from untrusted schema
        /// (`SQLITE_INNOCUOUS`).
        const SQLITE_INNOCUOUS = 0x0000_0020_0000;
    }
}
342
343impl Default for FunctionFlags {
344 #[inline]
345 fn default() -> FunctionFlags {
346 FunctionFlags::SQLITE_UTF8
347 }
348}
349
350impl Connection {
351 #[inline]
389 pub fn create_scalar_function<F, T>(
390 &self,
391 fn_name: &str,
392 n_arg: c_int,
393 flags: FunctionFlags,
394 x_func: F,
395 ) -> Result<()>
396 where
397 F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
398 T: ToSql,
399 {
400 self.db
401 .borrow_mut()
402 .create_scalar_function(fn_name, n_arg, flags, x_func)
403 }
404
405 #[inline]
412 pub fn create_aggregate_function<A, D, T>(
413 &self,
414 fn_name: &str,
415 n_arg: c_int,
416 flags: FunctionFlags,
417 aggr: D,
418 ) -> Result<()>
419 where
420 A: RefUnwindSafe + UnwindSafe,
421 D: Aggregate<A, T> + 'static,
422 T: ToSql,
423 {
424 self.db
425 .borrow_mut()
426 .create_aggregate_function(fn_name, n_arg, flags, aggr)
427 }
428
429 #[cfg(feature = "window")]
435 #[cfg_attr(docsrs, doc(cfg(feature = "window")))]
436 #[inline]
437 pub fn create_window_function<A, W, T>(
438 &self,
439 fn_name: &str,
440 n_arg: c_int,
441 flags: FunctionFlags,
442 aggr: W,
443 ) -> Result<()>
444 where
445 A: RefUnwindSafe + UnwindSafe,
446 W: WindowAggregate<A, T> + 'static,
447 T: ToSql,
448 {
449 self.db
450 .borrow_mut()
451 .create_window_function(fn_name, n_arg, flags, aggr)
452 }
453
454 #[inline]
465 pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> {
466 self.db.borrow_mut().remove_function(fn_name, n_arg)
467 }
468}
469
470impl InnerConnection {
471 fn create_scalar_function<F, T>(
472 &mut self,
473 fn_name: &str,
474 n_arg: c_int,
475 flags: FunctionFlags,
476 x_func: F,
477 ) -> Result<()>
478 where
479 F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
480 T: ToSql,
481 {
482 unsafe extern "C" fn call_boxed_closure<F, T>(
483 ctx: *mut sqlite3_context,
484 argc: c_int,
485 argv: *mut *mut sqlite3_value,
486 ) where
487 F: FnMut(&Context<'_>) -> Result<T>,
488 T: ToSql,
489 {
490 let r = catch_unwind(|| {
491 let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>();
492 assert!(!boxed_f.is_null(), "Internal error - null function pointer");
493 let ctx = Context {
494 ctx,
495 args: slice::from_raw_parts(argv, argc as usize),
496 };
497 (*boxed_f)(&ctx)
498 });
499 let t = match r {
500 Err(_) => {
501 report_error(ctx, &Error::UnwindingPanic);
502 return;
503 }
504 Ok(r) => r,
505 };
506 let t = t.as_ref().map(|t| ToSql::to_sql(t));
507
508 match t {
509 Ok(Ok(ref value)) => set_result(ctx, value),
510 Ok(Err(err)) => report_error(ctx, &err),
511 Err(err) => report_error(ctx, err),
512 }
513 }
514
515 let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
516 let c_name = str_to_cstring(fn_name)?;
517 let r = unsafe {
518 ffi::sqlite3_create_function_v2(
519 self.db(),
520 c_name.as_ptr(),
521 n_arg,
522 flags.bits(),
523 boxed_f.cast::<c_void>(),
524 Some(call_boxed_closure::<F, T>),
525 None,
526 None,
527 Some(free_boxed_value::<F>),
528 )
529 };
530 self.decode_result(r)
531 }
532
533 fn create_aggregate_function<A, D, T>(
534 &mut self,
535 fn_name: &str,
536 n_arg: c_int,
537 flags: FunctionFlags,
538 aggr: D,
539 ) -> Result<()>
540 where
541 A: RefUnwindSafe + UnwindSafe,
542 D: Aggregate<A, T> + 'static,
543 T: ToSql,
544 {
545 let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
546 let c_name = str_to_cstring(fn_name)?;
547 let r = unsafe {
548 ffi::sqlite3_create_function_v2(
549 self.db(),
550 c_name.as_ptr(),
551 n_arg,
552 flags.bits(),
553 boxed_aggr.cast::<c_void>(),
554 None,
555 Some(call_boxed_step::<A, D, T>),
556 Some(call_boxed_final::<A, D, T>),
557 Some(free_boxed_value::<D>),
558 )
559 };
560 self.decode_result(r)
561 }
562
563 #[cfg(feature = "window")]
564 fn create_window_function<A, W, T>(
565 &mut self,
566 fn_name: &str,
567 n_arg: c_int,
568 flags: FunctionFlags,
569 aggr: W,
570 ) -> Result<()>
571 where
572 A: RefUnwindSafe + UnwindSafe,
573 W: WindowAggregate<A, T> + 'static,
574 T: ToSql,
575 {
576 let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
577 let c_name = str_to_cstring(fn_name)?;
578 let r = unsafe {
579 ffi::sqlite3_create_window_function(
580 self.db(),
581 c_name.as_ptr(),
582 n_arg,
583 flags.bits(),
584 boxed_aggr.cast::<c_void>(),
585 Some(call_boxed_step::<A, W, T>),
586 Some(call_boxed_final::<A, W, T>),
587 Some(call_boxed_value::<A, W, T>),
588 Some(call_boxed_inverse::<A, W, T>),
589 Some(free_boxed_value::<W>),
590 )
591 };
592 self.decode_result(r)
593 }
594
595 fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
596 let c_name = str_to_cstring(fn_name)?;
597 let r = unsafe {
598 ffi::sqlite3_create_function_v2(
599 self.db(),
600 c_name.as_ptr(),
601 n_arg,
602 ffi::SQLITE_UTF8,
603 ptr::null_mut(),
604 None,
605 None,
606 None,
607 None,
608 )
609 };
610 self.decode_result(r)
611 }
612}
613
614unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
615 let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
616 if pac.is_null() {
617 return None;
618 }
619 Some(pac)
620}
621
622unsafe extern "C" fn call_boxed_step<A, D, T>(
623 ctx: *mut sqlite3_context,
624 argc: c_int,
625 argv: *mut *mut sqlite3_value,
626) where
627 A: RefUnwindSafe + UnwindSafe,
628 D: Aggregate<A, T>,
629 T: ToSql,
630{
631 let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) {
632 pac
633 } else {
634 ffi::sqlite3_result_error_nomem(ctx);
635 return;
636 };
637
638 let r = catch_unwind(|| {
639 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
640 assert!(
641 !boxed_aggr.is_null(),
642 "Internal error - null aggregate pointer"
643 );
644 let mut ctx = Context {
645 ctx,
646 args: slice::from_raw_parts(argv, argc as usize),
647 };
648
649 if (*pac as *mut A).is_null() {
650 *pac = Box::into_raw(Box::new((*boxed_aggr).init(&mut ctx)?));
651 }
652
653 (*boxed_aggr).step(&mut ctx, &mut **pac)
654 });
655 let r = match r {
656 Err(_) => {
657 report_error(ctx, &Error::UnwindingPanic);
658 return;
659 }
660 Ok(r) => r,
661 };
662 match r {
663 Ok(_) => {}
664 Err(err) => report_error(ctx, &err),
665 };
666}
667
#[cfg(feature = "window")]
/// `xInverse` trampoline for window functions. It removes the current row from
/// the accumulator via [`WindowAggregate::inverse`]. Panics and `Err` results
/// are reported to SQLite as function errors.
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
    ctx: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    let slot: *mut *mut A = match aggregate_context(ctx, std::mem::size_of::<*mut A>()) {
        Some(slot) => slot,
        None => {
            ffi::sqlite3_result_error_nomem(ctx);
            return;
        }
    };

    let outcome = catch_unwind(|| {
        let aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(!aggr.is_null(), "Internal error - null aggregate pointer");
        let mut fn_ctx = Context {
            ctx,
            args: slice::from_raw_parts(argv, argc as usize),
        };
        // NOTE(review): assumes SQLite only calls xInverse after xStep has
        // created the accumulator, so `*slot` is non-null here.
        (*aggr).inverse(&mut fn_ctx, &mut **slot)
    });

    match outcome {
        Err(_) => report_error(ctx, &Error::UnwindingPanic),
        Ok(Err(err)) => report_error(ctx, &err),
        Ok(Ok(())) => {}
    }
}
709
710unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
711where
712 A: RefUnwindSafe + UnwindSafe,
713 D: Aggregate<A, T>,
714 T: ToSql,
715{
716 let a: Option<A> = match aggregate_context(ctx, 0) {
719 Some(pac) => {
720 if (*pac as *mut A).is_null() {
721 None
722 } else {
723 let a = Box::from_raw(*pac);
724 Some(*a)
725 }
726 }
727 None => None,
728 };
729
730 let r = catch_unwind(|| {
731 let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
732 assert!(
733 !boxed_aggr.is_null(),
734 "Internal error - null aggregate pointer"
735 );
736 let mut ctx = Context { ctx, args: &mut [] };
737 (*boxed_aggr).finalize(&mut ctx, a)
738 });
739 let t = match r {
740 Err(_) => {
741 report_error(ctx, &Error::UnwindingPanic);
742 return;
743 }
744 Ok(r) => r,
745 };
746 let t = t.as_ref().map(|t| ToSql::to_sql(t));
747 match t {
748 Ok(Ok(ref value)) => set_result(ctx, value),
749 Ok(Err(err)) => report_error(ctx, &err),
750 Err(err) => report_error(ctx, err),
751 }
752}
753
#[cfg(feature = "window")]
/// `xValue` trampoline for window functions.
///
/// It reports the current aggregate value via [`WindowAggregate::value`],
/// borrowing (not consuming) the accumulator. Panics, `Err` results and
/// `ToSql` failures are reported to SQLite as function errors.
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    // Size 0: look up only. A null slot means no accumulator exists yet.
    let acc: Option<&A> = match aggregate_context::<A>(ctx, 0) {
        Some(slot) if !(*slot).is_null() => Some(&**slot),
        _ => None,
    };

    let outcome = catch_unwind(|| {
        let aggr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(!aggr.is_null(), "Internal error - null aggregate pointer");
        (*aggr).value(acc)
    });

    let current = match outcome {
        Ok(current) => current,
        Err(_) => {
            report_error(ctx, &Error::UnwindingPanic);
            return;
        }
    };

    match current.as_ref().map(|value| value.to_sql()) {
        Ok(Ok(ref output)) => set_result(ctx, output),
        Ok(Err(err)) => report_error(ctx, &err),
        Err(err) => report_error(ctx, err),
    }
}
797
#[cfg(test)]
mod test {
    use regex::Regex;
    use std::os::raw::c_double;

    #[cfg(feature = "window")]
    use crate::functions::WindowAggregate;
    use crate::functions::{Aggregate, Context, FunctionFlags};
    use crate::{Connection, Error, Result};

    /// Flags used when registering the deterministic test functions.
    fn deterministic_utf8() -> FunctionFlags {
        FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC
    }

    /// Scalar `half(x)`: returns `x / 2`.
    fn half(ctx: &Context<'_>) -> Result<c_double> {
        assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
        ctx.get::<c_double>(0).map(|value| value / 2f64)
    }

    #[test]
    fn test_function_half() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function("half", 1, deterministic_utf8(), half)?;

        let halved: f64 = db.one_column("SELECT half(6)")?;
        assert!((3f64 - halved).abs() < f64::EPSILON);
        Ok(())
    }

    #[test]
    fn test_remove_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function("half", 1, deterministic_utf8(), half)?;
        let halved: f64 = db.one_column("SELECT half(6)")?;
        assert!((3f64 - halved).abs() < f64::EPSILON);

        // Once removed, calling the function must fail.
        db.remove_function("half", 1)?;
        let after_removal: Result<f64> = db.one_column("SELECT half(6)");
        assert!(after_removal.is_err());
        Ok(())
    }

    /// Scalar `regexp(pattern, text)`. The compiled pattern is cached as
    /// auxiliary data on argument 0 so later rows can reuse it.
    fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> {
        assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
        let regexp: std::sync::Arc<Regex> = ctx
            .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
                Ok(Regex::new(vr.as_str()?)?)
            })?;

        let text = ctx
            .get_raw(1)
            .as_str()
            .map_err(|e| Error::UserFunctionError(e.into()))?;
        Ok(regexp.is_match(text))
    }

    #[test]
    fn test_function_regexp_with_auxilliary() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            "BEGIN;
             CREATE TABLE foo (x string);
             INSERT INTO foo VALUES ('lisa');
             INSERT INTO foo VALUES ('lXsi');
             INSERT INTO foo VALUES ('lisX');
             END;",
        )?;
        db.create_scalar_function("regexp", 2, deterministic_utf8(), regexp_with_auxilliary)?;

        let single: bool = db.one_column("SELECT regexp('l.s[aeiouy]', 'lisa')")?;
        assert!(single);

        let matching: i64 =
            db.one_column("SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1")?;
        assert_eq!(2, matching);
        Ok(())
    }

    #[test]
    fn test_varargs_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function("my_concat", -1, deterministic_utf8(), |ctx| {
            // Concatenate every argument as text, stopping at the first failure.
            (0..ctx.len())
                .map(|idx| ctx.get::<String>(idx))
                .collect::<Result<String>>()
        })?;

        let cases = [
            ("", "SELECT my_concat()"),
            ("onetwo", "SELECT my_concat('one', 'two')"),
            ("abc", "SELECT my_concat('a', 'b', 'c')"),
        ];
        for &(expected, query) in cases.iter() {
            let actual: String = db.one_column(query)?;
            assert_eq!(expected, actual);
        }
        Ok(())
    }

    #[test]
    fn test_get_aux_type_checking() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
            // The first row stores an i64; the second checks typed retrieval.
            let reuse_aux = ctx.get::<bool>(1)?;
            if reuse_aux {
                assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
                assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100);
            } else {
                ctx.set_aux::<i64>(0, 100)?;
            }
            Ok(true)
        })?;

        let res: bool =
            db.one_column("SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)")?;
        assert!(res);
        Ok(())
    }

    /// Integer sum aggregate; yields `NULL` for an empty group.
    struct Sum;
    /// Row-count aggregate; yields `0` for an empty group.
    struct Count;

    impl Aggregate<i64, Option<i64>> for Sum {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, ctx: &mut Context<'_>, acc: &mut i64) -> Result<()> {
            let value = ctx.get::<i64>(0)?;
            *acc += value;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, acc: Option<i64>) -> Result<Option<i64>> {
            Ok(acc)
        }
    }

    impl Aggregate<i64, i64> for Count {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, _ctx: &mut Context<'_>, acc: &mut i64) -> Result<()> {
            *acc += 1;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, acc: Option<i64>) -> Result<i64> {
            Ok(acc.unwrap_or_default())
        }
    }

    #[test]
    fn test_sum() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function("my_sum", 1, deterministic_utf8(), Sum)?;

        // No rows: `step` never runs, so `finalize` receives `None`.
        let empty: Option<i64> =
            db.one_column("SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)")?;
        assert!(empty.is_none());

        let single: i64 =
            db.one_column("SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)")?;
        assert_eq!(4, single);

        let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
                        2, 1)";
        let dual: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?;
        assert_eq!((4, 2), dual);
        Ok(())
    }

    #[test]
    fn test_count() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function("my_count", -1, deterministic_utf8(), Count)?;

        let empty: i64 =
            db.one_column("SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)")?;
        assert_eq!(empty, 0);

        let two_rows: i64 =
            db.one_column("SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)")?;
        assert_eq!(2, two_rows);
        Ok(())
    }

    #[cfg(feature = "window")]
    impl WindowAggregate<i64, Option<i64>> for Sum {
        fn inverse(&self, ctx: &mut Context<'_>, acc: &mut i64) -> Result<()> {
            let value = ctx.get::<i64>(0)?;
            *acc -= value;
            Ok(())
        }

        fn value(&self, acc: Option<&i64>) -> Result<Option<i64>> {
            Ok(acc.copied())
        }
    }

    #[test]
    #[cfg(feature = "window")]
    fn test_window() -> Result<()> {
        use fallible_iterator::FallibleIterator;

        let db = Connection::open_in_memory()?;
        db.create_window_function("sumint", 1, deterministic_utf8(), Sum)?;
        db.execute_batch(
            "CREATE TABLE t3(x, y);
             INSERT INTO t3 VALUES('a', 4),
                                  ('b', 5),
                                  ('c', 3),
                                  ('d', 8),
                                  ('e', 1);",
        )?;

        let mut stmt = db.prepare(
            "SELECT x, sumint(y) OVER (
                 ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
             ) AS sum_y
             FROM t3 ORDER BY x;",
        )?;

        let results: Vec<(String, i64)> = stmt
            .query([])?
            .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
            .collect()?;
        // Each row sums itself and its immediate neighbours.
        let expected = vec![
            ("a".to_owned(), 9),
            ("b".to_owned(), 12),
            ("c".to_owned(), 16),
            ("d".to_owned(), 12),
            ("e".to_owned(), 9),
        ];
        assert_eq!(expected, results);
        Ok(())
    }
}