1#![allow(clippy::unnecessary_literal_bound)]
13
14use std::any::Any;
15use std::collections::HashMap;
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::{Arc, OnceLock};
18
19use fsqlite_error::FrankenError;
20use fsqlite_types::SqliteValue;
21use tracing::debug;
22
/// Number of function calls recorded through [`record_func_call`] or
/// [`record_func_call_count_only`] since start-up or the last reset.
static FSQLITE_FUNC_CALLS_TOTAL: AtomicU64 = AtomicU64::new(0);
/// Sum, in microseconds, of every `duration_us` passed to [`record_func_call`].
static FSQLITE_FUNC_EVAL_DURATION_US_TOTAL: AtomicU64 = AtomicU64::new(0);

/// Point-in-time copy of the process-wide function-call counters.
///
/// The two counters are loaded one after the other (not as a single atomic
/// unit), so a snapshot taken while other threads are recording may pair a
/// call count with a duration total that does not yet include every call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FuncMetricsSnapshot {
    /// Calls recorded by either `record_func_call*` function.
    pub calls_total: u64,
    /// Accumulated evaluation time in microseconds.
    pub eval_duration_us_total: u64,
}

/// Reads both function-call counters and returns them as a snapshot.
#[must_use]
pub fn func_metrics_snapshot() -> FuncMetricsSnapshot {
    let calls_total = FSQLITE_FUNC_CALLS_TOTAL.load(Ordering::Relaxed);
    let eval_duration_us_total = FSQLITE_FUNC_EVAL_DURATION_US_TOTAL.load(Ordering::Relaxed);
    FuncMetricsSnapshot {
        calls_total,
        eval_duration_us_total,
    }
}

/// Zeroes the call counter and then the duration counter.
pub fn reset_func_metrics() {
    for counter in [&FSQLITE_FUNC_CALLS_TOTAL, &FSQLITE_FUNC_EVAL_DURATION_US_TOTAL] {
        counter.store(0, Ordering::Relaxed);
    }
}

/// Records one function call that took `duration_us` microseconds.
///
/// Both counters use `fetch_add`, which wraps on overflow.
pub fn record_func_call(duration_us: u64) {
    record_func_call_count_only();
    FSQLITE_FUNC_EVAL_DURATION_US_TOTAL.fetch_add(duration_us, Ordering::Relaxed);
}

/// Records one function call without contributing any evaluation time.
pub fn record_func_call_count_only() {
    FSQLITE_FUNC_CALLS_TOTAL.fetch_add(1, Ordering::Relaxed);
}
64
/// Count of user-defined-function registrations reported via
/// [`record_udf_registered`] since start-up or the last reset.
static FSQLITE_UDF_REGISTERED: AtomicU64 = AtomicU64::new(0);

/// Increments the UDF-registration counter by one (wrapping on overflow).
pub fn record_udf_registered() {
    let _ = FSQLITE_UDF_REGISTERED.fetch_add(1, Ordering::Relaxed);
}

/// Current value of the UDF-registration counter.
#[must_use]
pub fn udf_registered_count() -> u64 {
    FSQLITE_UDF_REGISTERED.load(Ordering::Relaxed)
}

/// Sets the UDF-registration counter back to zero.
pub fn reset_udf_metrics() {
    let counter = &FSQLITE_UDF_REGISTERED;
    counter.store(0, Ordering::Relaxed);
}
85
86pub mod agg_builtins;
87pub mod aggregate;
88pub mod authorizer;
89pub mod builtins;
90pub mod collation;
91pub mod datetime;
92pub mod math;
93pub mod scalar;
94pub mod vtab;
95pub mod window;
96pub mod window_builtins;
97
98pub use agg_builtins::register_aggregate_builtins;
99pub use aggregate::{AggregateAdapter, AggregateFunction};
100pub use authorizer::{AuthAction, AuthResult, Authorizer, AuthorizerAction, AuthorizerDecision};
101pub use builtins::{
102 ChangeTrackingState, get_last_changes, get_last_insert_rowid, get_total_changes,
103 register_builtins, reset_total_changes, set_change_tracking_state, set_last_changes,
104 set_last_insert_rowid, sqlite_compile_options, sqlite_compileoption_used,
105};
106pub use collation::{
107 BinaryCollation, CollationAnnotation, CollationFunction, CollationRegistry, CollationSource,
108 NoCaseCollation, RtrimCollation, resolve_collation,
109};
110pub use datetime::register_datetime_builtins;
111pub use math::register_math_builtins;
112pub use scalar::ScalarFunction;
113pub use vtab::{
114 ColumnContext, ConstraintOp, IndexConstraint, IndexConstraintUsage, IndexInfo, IndexOrderBy,
115 VirtualTable, VirtualTableCursor,
116};
117pub use window::{WindowAdapter, WindowFunction};
118pub use window_builtins::register_window_builtins;
119
/// Which registry table (scalar, aggregate, or window) a builtin function
/// was registered in.
///
/// Variant order is significant: the derived `Ord` is the leading component
/// of the sort key used by `builtin_function_surface_inventory`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum BuiltinFunctionFamily {
    Scalar,
    Aggregate,
    Window,
}

impl BuiltinFunctionFamily {
    /// Lowercase, stable text label for this family.
    #[must_use]
    pub const fn label(self) -> &'static str {
        match self {
            BuiltinFunctionFamily::Window => "window",
            BuiltinFunctionFamily::Aggregate => "aggregate",
            BuiltinFunctionFamily::Scalar => "scalar",
        }
    }
}
138
/// Finer-grained classification of a builtin function for the surface
/// inventory; scalars are split into core, math, and date/time groups.
///
/// Variant order is significant: the derived `Ord` is the second component
/// of the inventory's sort key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum BuiltinFunctionClass {
    CoreScalar,
    MathScalar,
    DateTimeScalar,
    Aggregate,
    Window,
}

impl BuiltinFunctionClass {
    /// Snake-case text label for this class.
    #[must_use]
    pub const fn label(self) -> &'static str {
        match self {
            BuiltinFunctionClass::Window => "window",
            BuiltinFunctionClass::Aggregate => "aggregate",
            BuiltinFunctionClass::DateTimeScalar => "datetime_scalar",
            BuiltinFunctionClass::MathScalar => "math_scalar",
            BuiltinFunctionClass::CoreScalar => "core_scalar",
        }
    }
}
161
/// One row of the builtin-function surface inventory: a single registered
/// `(family, name, arity)` combination plus its classification metadata.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BuiltinFunctionSurfaceEntry {
    /// Function name; the inventory builder lowercases the registry's
    /// canonical (uppercase) key before storing it here.
    pub name: String,
    /// Registered arity; `-1` denotes a variadic registration.
    pub num_args: i32,
    /// Registry table the entry was taken from.
    pub family: BuiltinFunctionFamily,
    /// Classification (core/math/date-time scalar, aggregate, or window).
    pub class: BuiltinFunctionClass,
    /// Whether `name` is flagged as an alternate spelling of another builtin.
    pub is_alias: bool,
    /// Identifier of the function surface this entry belongs to.
    pub surface_id: &'static str,
}
178
/// Surface identifier attached to scalar and aggregate builtin inventory entries.
const CORE_FUNCTION_SURFACE_ID: &str = "SURF-FUNC-CORE-011";
/// Surface identifier attached to window builtin inventory entries.
const WINDOW_FUNCTION_SURFACE_ID: &str = "SURF-FUNC-WINDOW-012";
181
182#[must_use]
188pub fn builtin_function_surface_inventory() -> &'static [BuiltinFunctionSurfaceEntry] {
189 static INVENTORY: OnceLock<Vec<BuiltinFunctionSurfaceEntry>> = OnceLock::new();
190 INVENTORY
191 .get_or_init(|| {
192 let mut registry = FunctionRegistry::new();
193 register_builtins(&mut registry);
194 register_window_builtins(&mut registry);
195
196 let mut entries = Vec::with_capacity(
197 registry.scalars.len() + registry.aggregates.len() + registry.windows.len(),
198 );
199 extend_builtin_surface_entries(
200 &mut entries,
201 BuiltinFunctionFamily::Scalar,
202 registry.scalars.keys(),
203 );
204 extend_builtin_surface_entries(
205 &mut entries,
206 BuiltinFunctionFamily::Aggregate,
207 registry.aggregates.keys(),
208 );
209 extend_builtin_surface_entries(
210 &mut entries,
211 BuiltinFunctionFamily::Window,
212 registry.windows.keys(),
213 );
214 entries.sort_by(|left, right| {
215 (left.family, left.class, &left.name, left.num_args).cmp(&(
216 right.family,
217 right.class,
218 &right.name,
219 right.num_args,
220 ))
221 });
222 entries
223 })
224 .as_slice()
225}
226
/// Type-erased aggregate function as stored in [`FunctionRegistry`].
///
/// The `State` is boxed as `dyn Any + Send` so implementations with different
/// concrete state types can live in the same map.
pub type ErasedAggregateFunction = dyn AggregateFunction<State = Box<dyn Any + Send>>;

/// Type-erased window function as stored in [`FunctionRegistry`].
///
/// The `State` is boxed as `dyn Any + Send` so implementations with different
/// concrete state types can live in the same map.
pub type ErasedWindowFunction = dyn WindowFunction<State = Box<dyn Any + Send>>;
232
233#[derive(Debug, Clone, Hash, Eq, PartialEq)]
238pub struct FunctionKey {
239 pub name: String,
241 pub num_args: i32,
243}
244
245impl FunctionKey {
246 #[must_use]
248 pub fn new(name: &str, num_args: i32) -> Self {
249 Self {
250 name: canonical_name(name),
251 num_args,
252 }
253 }
254}
255
/// Lookup tables for SQL functions, keyed by [`FunctionKey`]
/// (canonical name + arity).
///
/// Scalar, aggregate, and window functions live in separate maps. Aggregate
/// and window entries are stored type-erased so implementations with
/// different state types can share one map.
#[derive(Default)]
pub struct FunctionRegistry {
    // Scalar functions.
    scalars: HashMap<FunctionKey, Arc<dyn ScalarFunction>>,
    // Aggregates, wrapped by `AggregateAdapter` at registration time.
    aggregates: HashMap<FunctionKey, Arc<ErasedAggregateFunction>>,
    // Window functions, wrapped by `WindowAdapter` at registration time.
    windows: HashMap<FunctionKey, Arc<ErasedWindowFunction>>,
}
271
/// Placeholder scalar handed out when a known function name is called with an
/// unsupported number of arguments; invoking it always fails.
struct WrongArgCountScalarFunction {
    /// Lowercased name shown in the error and reported by `name()`.
    display_name: String,
}

/// Builds the error text raised for an arity mismatch on `display_name`.
fn wrong_arg_count_message(display_name: &str) -> String {
    ["wrong number of arguments to function ", display_name, "()"].concat()
}

/// Converts a canonical (uppercase) registry name into the spelling used in
/// arity-mismatch errors: ASCII letters lowercased, everything else untouched.
fn wrong_arg_display_name(canonical: &str) -> String {
    let mut display = canonical.to_owned();
    display.make_ascii_lowercase();
    display
}
283
284impl WrongArgCountScalarFunction {
285 fn new(canonical: &str) -> Self {
286 Self {
287 display_name: wrong_arg_display_name(canonical),
288 }
289 }
290
291 fn message(&self) -> String {
292 wrong_arg_count_message(&self.display_name)
293 }
294}
295
296impl ScalarFunction for WrongArgCountScalarFunction {
297 fn invoke(&self, _args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
298 Err(FrankenError::function_error(self.message()))
299 }
300
301 fn num_args(&self) -> i32 {
302 -1
303 }
304
305 fn name(&self) -> &str {
306 &self.display_name
307 }
308}
309
310struct WrongArgCountAggregateFunction {
311 display_name: String,
312}
313
314impl WrongArgCountAggregateFunction {
315 fn new(canonical: &str) -> Self {
316 Self {
317 display_name: wrong_arg_display_name(canonical),
318 }
319 }
320
321 fn message(&self) -> String {
322 wrong_arg_count_message(&self.display_name)
323 }
324}
325
326impl AggregateFunction for WrongArgCountAggregateFunction {
327 type State = ();
328
329 fn initial_state(&self) -> Self::State {}
330
331 fn step(&self, _state: &mut Self::State, _args: &[SqliteValue]) -> fsqlite_error::Result<()> {
332 Err(FrankenError::function_error(self.message()))
333 }
334
335 fn finalize(&self, _state: Self::State) -> fsqlite_error::Result<SqliteValue> {
336 Err(FrankenError::function_error(self.message()))
337 }
338
339 fn num_args(&self) -> i32 {
340 -1
341 }
342
343 fn name(&self) -> &str {
344 &self.display_name
345 }
346}
347
348struct WrongArgCountWindowFunction {
349 display_name: String,
350}
351
352impl WrongArgCountWindowFunction {
353 fn new(canonical: &str) -> Self {
354 Self {
355 display_name: wrong_arg_display_name(canonical),
356 }
357 }
358
359 fn message(&self) -> String {
360 wrong_arg_count_message(&self.display_name)
361 }
362}
363
364impl WindowFunction for WrongArgCountWindowFunction {
365 type State = ();
366
367 fn initial_state(&self) -> Self::State {}
368
369 fn step(&self, _state: &mut Self::State, _args: &[SqliteValue]) -> fsqlite_error::Result<()> {
370 Err(FrankenError::function_error(self.message()))
371 }
372
373 fn inverse(
374 &self,
375 _state: &mut Self::State,
376 _args: &[SqliteValue],
377 ) -> fsqlite_error::Result<()> {
378 Err(FrankenError::function_error(self.message()))
379 }
380
381 fn value(&self, _state: &Self::State) -> fsqlite_error::Result<SqliteValue> {
382 Err(FrankenError::function_error(self.message()))
383 }
384
385 fn finalize(&self, _state: Self::State) -> fsqlite_error::Result<SqliteValue> {
386 Err(FrankenError::function_error(self.message()))
387 }
388
389 fn num_args(&self) -> i32 {
390 -1
391 }
392
393 fn name(&self) -> &str {
394 &self.display_name
395 }
396}
397
398impl FunctionRegistry {
399 #[must_use]
401 pub fn new() -> Self {
402 Self::default()
403 }
404
405 #[must_use]
410 pub fn clone_from_arc(arc: &Arc<Self>) -> Self {
411 Self {
412 scalars: arc.scalars.clone(),
413 aggregates: arc.aggregates.clone(),
414 windows: arc.windows.clone(),
415 }
416 }
417
418 pub fn register_scalar<F>(&mut self, function: F) -> Option<Arc<dyn ScalarFunction>>
423 where
424 F: ScalarFunction + 'static,
425 {
426 let key = FunctionKey::new(function.name(), function.num_args());
427 self.scalars.insert(key, Arc::new(function))
428 }
429
430 pub fn register_aggregate<F>(&mut self, function: F) -> Option<Arc<ErasedAggregateFunction>>
434 where
435 F: AggregateFunction + 'static,
436 F::State: 'static,
437 {
438 let key = FunctionKey::new(function.name(), function.num_args());
439 self.aggregates
440 .insert(key, Arc::new(AggregateAdapter::new(function)))
441 }
442
443 pub fn register_window<F>(&mut self, function: F) -> Option<Arc<ErasedWindowFunction>>
447 where
448 F: WindowFunction + 'static,
449 F::State: 'static,
450 {
451 let key = FunctionKey::new(function.name(), function.num_args());
452 self.windows
453 .insert(key, Arc::new(WindowAdapter::new(function)))
454 }
455
456 #[must_use]
461 pub fn find_scalar(&self, name: &str, num_args: i32) -> Option<Arc<dyn ScalarFunction>> {
462 let canon = canonical_name(name);
463 self.find_scalar_precanonical(&canon, num_args)
464 }
465
466 #[must_use]
471 pub fn find_scalar_precanonical(
472 &self,
473 canonical: &str,
474 num_args: i32,
475 ) -> Option<Arc<dyn ScalarFunction>> {
476 let exact = FunctionKey {
477 name: canonical.to_owned(),
478 num_args,
479 };
480 if let Some(f) = self.scalars.get(&exact) {
481 debug!(name = %canonical, arity = num_args, kind = "scalar", hit = "exact", "registry lookup");
482 return Some(Arc::clone(f));
483 }
484 let variadic = FunctionKey {
485 name: canonical.to_owned(),
486 num_args: -1,
487 };
488 if let Some(function) = self.scalars.get(&variadic) {
489 if function.accepts_arg_count(num_args) {
490 debug!(name = %canonical, arity = num_args, kind = "scalar", hit = "variadic", "registry lookup");
491 return Some(Arc::clone(function));
492 }
493 debug!(name = %canonical, arity = num_args, kind = "scalar", hit = "wrong_arity", "registry lookup");
494 return Some(Arc::new(WrongArgCountScalarFunction::new(canonical)));
495 }
496 if self.scalars.keys().any(|key| key.name == canonical) {
497 debug!(name = %canonical, arity = num_args, kind = "scalar", hit = "wrong_arity", "registry lookup");
498 return Some(Arc::new(WrongArgCountScalarFunction::new(canonical)));
499 }
500 debug!(
501 name = %canonical,
502 arity = num_args,
503 kind = "scalar",
504 hit = "miss",
505 "registry lookup"
506 );
507 None
508 }
509
510 #[must_use]
514 pub fn find_aggregate(
515 &self,
516 name: &str,
517 num_args: i32,
518 ) -> Option<Arc<ErasedAggregateFunction>> {
519 let canon = canonical_name(name);
520 self.find_aggregate_precanonical(&canon, num_args)
521 }
522
523 #[must_use]
525 pub fn find_aggregate_precanonical(
526 &self,
527 canonical: &str,
528 num_args: i32,
529 ) -> Option<Arc<ErasedAggregateFunction>> {
530 let exact = FunctionKey {
531 name: canonical.to_owned(),
532 num_args,
533 };
534 if let Some(f) = self.aggregates.get(&exact) {
535 debug!(name = %canonical, arity = num_args, kind = "aggregate", hit = "exact", "registry lookup");
536 return Some(Arc::clone(f));
537 }
538 let variadic = FunctionKey {
539 name: canonical.to_owned(),
540 num_args: -1,
541 };
542 if let Some(function) = self.aggregates.get(&variadic) {
543 if function.accepts_arg_count(num_args) {
544 debug!(name = %canonical, arity = num_args, kind = "aggregate", hit = "variadic", "registry lookup");
545 return Some(Arc::clone(function));
546 }
547 debug!(name = %canonical, arity = num_args, kind = "aggregate", hit = "wrong_arity", "registry lookup");
548 return Some(Arc::new(AggregateAdapter::new(
549 WrongArgCountAggregateFunction::new(canonical),
550 )));
551 }
552 if self.aggregates.keys().any(|key| key.name == canonical) {
553 debug!(name = %canonical, arity = num_args, kind = "aggregate", hit = "wrong_arity", "registry lookup");
554 return Some(Arc::new(AggregateAdapter::new(
555 WrongArgCountAggregateFunction::new(canonical),
556 )));
557 }
558 debug!(
559 name = %canonical,
560 arity = num_args,
561 kind = "aggregate",
562 hit = "miss",
563 "registry lookup"
564 );
565 None
566 }
567
568 #[must_use]
572 pub fn find_window(&self, name: &str, num_args: i32) -> Option<Arc<ErasedWindowFunction>> {
573 let canon = canonical_name(name);
574 let exact = FunctionKey {
575 name: canon.clone(),
576 num_args,
577 };
578 if let Some(f) = self.windows.get(&exact) {
579 debug!(name = %canon, arity = num_args, kind = "window", hit = "exact", "registry lookup");
580 return Some(Arc::clone(f));
581 }
582 let variadic = FunctionKey {
583 name: canon.clone(),
584 num_args: -1,
585 };
586 if let Some(function) = self.windows.get(&variadic) {
587 if function.accepts_arg_count(num_args) {
588 debug!(name = %canon, arity = num_args, kind = "window", hit = "variadic", "registry lookup");
589 return Some(Arc::clone(function));
590 }
591 debug!(name = %canon, arity = num_args, kind = "window", hit = "wrong_arity", "registry lookup");
592 return Some(Arc::new(WindowAdapter::new(
593 WrongArgCountWindowFunction::new(&canon),
594 )));
595 }
596 if self.windows.keys().any(|key| key.name == canon) {
597 debug!(name = %canon, arity = num_args, kind = "window", hit = "wrong_arity", "registry lookup");
598 return Some(Arc::new(WindowAdapter::new(
599 WrongArgCountWindowFunction::new(&canon),
600 )));
601 }
602 debug!(
603 name = %canon,
604 arity = num_args,
605 kind = "window",
606 hit = "miss",
607 "registry lookup"
608 );
609 None
610 }
611
612 #[must_use]
615 pub fn contains_scalar(&self, name: &str) -> bool {
616 let canon = canonical_name(name);
617 self.scalars.keys().any(|k| k.name == canon)
618 }
619
620 #[must_use]
623 pub fn contains_aggregate(&self, name: &str) -> bool {
624 let canon = canonical_name(name);
625 self.aggregates.keys().any(|k| k.name == canon)
626 }
627
628 #[must_use]
631 pub fn contains_window(&self, name: &str) -> bool {
632 let canon = canonical_name(name);
633 self.windows.keys().any(|k| k.name == canon)
634 }
635
636 #[must_use]
643 pub fn window_accepts_arg_count(&self, name: &str, num_args: i32) -> Option<bool> {
644 let canon = canonical_name(name);
645 let exact = FunctionKey {
646 name: canon.clone(),
647 num_args,
648 };
649 if let Some(function) = self.windows.get(&exact) {
650 return Some(function.accepts_arg_count(num_args));
651 }
652
653 let variadic = FunctionKey {
654 name: canon.clone(),
655 num_args: -1,
656 };
657 if let Some(function) = self.windows.get(&variadic) {
658 return Some(function.accepts_arg_count(num_args));
659 }
660
661 self.windows
662 .keys()
663 .any(|key| key.name == canon)
664 .then_some(false)
665 }
666
667 #[must_use]
671 pub fn aggregate_names_lowercase(&self) -> Vec<String> {
672 let mut names: Vec<String> = self
673 .aggregates
674 .keys()
675 .map(|k| k.name.to_ascii_lowercase())
676 .collect();
677 names.sort();
678 names.dedup();
679 names
680 }
681}
682
683fn extend_builtin_surface_entries<'a>(
684 entries: &mut Vec<BuiltinFunctionSurfaceEntry>,
685 family: BuiltinFunctionFamily,
686 keys: impl Iterator<Item = &'a FunctionKey>,
687) {
688 for key in keys {
689 let name = key.name.to_ascii_lowercase();
690 let class = builtin_function_class(&name, family);
691 entries.push(BuiltinFunctionSurfaceEntry {
692 is_alias: builtin_function_alias_flag(&name, family),
693 surface_id: builtin_function_surface_id(family),
694 name,
695 num_args: key.num_args,
696 family,
697 class,
698 });
699 }
700}
701
702fn builtin_function_class(name: &str, family: BuiltinFunctionFamily) -> BuiltinFunctionClass {
703 match family {
704 BuiltinFunctionFamily::Aggregate => BuiltinFunctionClass::Aggregate,
705 BuiltinFunctionFamily::Window => BuiltinFunctionClass::Window,
706 BuiltinFunctionFamily::Scalar => {
707 if matches!(
708 name,
709 "acos"
710 | "acosh"
711 | "asin"
712 | "asinh"
713 | "atan"
714 | "atan2"
715 | "atanh"
716 | "ceil"
717 | "ceiling"
718 | "cos"
719 | "cosh"
720 | "degrees"
721 | "exp"
722 | "floor"
723 | "ln"
724 | "log"
725 | "log10"
726 | "log2"
727 | "mod"
728 | "pi"
729 | "pow"
730 | "power"
731 | "radians"
732 | "sin"
733 | "sinh"
734 | "sqrt"
735 | "tan"
736 | "tanh"
737 | "trunc"
738 ) {
739 BuiltinFunctionClass::MathScalar
740 } else if matches!(
741 name,
742 "date" | "datetime" | "julianday" | "strftime" | "time" | "timediff" | "unixepoch"
743 ) {
744 BuiltinFunctionClass::DateTimeScalar
745 } else {
746 BuiltinFunctionClass::CoreScalar
747 }
748 }
749 }
750}
751
752fn builtin_function_alias_flag(name: &str, family: BuiltinFunctionFamily) -> bool {
753 match family {
754 BuiltinFunctionFamily::Scalar => {
755 matches!(name, "ceiling" | "if" | "power" | "printf" | "substring")
756 }
757 BuiltinFunctionFamily::Aggregate | BuiltinFunctionFamily::Window => name == "string_agg",
758 }
759}
760
761const fn builtin_function_surface_id(family: BuiltinFunctionFamily) -> &'static str {
762 match family {
763 BuiltinFunctionFamily::Window => WINDOW_FUNCTION_SURFACE_ID,
764 BuiltinFunctionFamily::Scalar | BuiltinFunctionFamily::Aggregate => {
765 CORE_FUNCTION_SURFACE_ID
766 }
767 }
768}
769
/// Normalizes a function name for registry keys and lookups.
///
/// Strips leading and trailing whitespace, then uppercases ASCII letters.
/// Non-ASCII characters are left unchanged.
fn canonical_name(name: &str) -> String {
    let mut canonical = String::from(name.trim());
    canonical.make_ascii_uppercase();
    canonical
}
773
774#[cfg(test)]
775mod tests {
776 use std::collections::BTreeSet;
777
778 use fsqlite_types::SqliteValue;
779
780 use super::*;
781
782 fn runtime_registry_surface_keys() -> BTreeSet<(BuiltinFunctionFamily, String, i32)> {
783 let mut registry = FunctionRegistry::new();
784 register_builtins(&mut registry);
785 register_window_builtins(&mut registry);
786
787 let scalar_keys = registry
788 .scalars
789 .keys()
790 .map(|key| {
791 (
792 BuiltinFunctionFamily::Scalar,
793 key.name.to_ascii_lowercase(),
794 key.num_args,
795 )
796 })
797 .collect::<BTreeSet<_>>();
798 let aggregate_keys = registry
799 .aggregates
800 .keys()
801 .map(|key| {
802 (
803 BuiltinFunctionFamily::Aggregate,
804 key.name.to_ascii_lowercase(),
805 key.num_args,
806 )
807 })
808 .collect::<BTreeSet<_>>();
809 let window_keys = registry
810 .windows
811 .keys()
812 .map(|key| {
813 (
814 BuiltinFunctionFamily::Window,
815 key.name.to_ascii_lowercase(),
816 key.num_args,
817 )
818 })
819 .collect::<BTreeSet<_>>();
820
821 scalar_keys
822 .into_iter()
823 .chain(aggregate_keys)
824 .chain(window_keys)
825 .collect()
826 }
827
828 fn inventory_surface_keys() -> BTreeSet<(BuiltinFunctionFamily, String, i32)> {
829 builtin_function_surface_inventory()
830 .iter()
831 .map(|entry| (entry.family, entry.name.clone(), entry.num_args))
832 .collect()
833 }
834
835 fn find_surface_entry(
836 family: BuiltinFunctionFamily,
837 name: &str,
838 num_args: i32,
839 ) -> &'static BuiltinFunctionSurfaceEntry {
840 builtin_function_surface_inventory()
841 .iter()
842 .find(|entry| {
843 entry.family == family && entry.name == name && entry.num_args == num_args
844 })
845 .unwrap_or_else(|| {
846 unreachable!(
847 "missing builtin surface entry: family={} name={} arity={}",
848 family.label(),
849 name,
850 num_args
851 )
852 })
853 }
854
855 #[test]
856 fn test_builtin_function_surface_inventory_matches_live_registry() {
857 let inventory = builtin_function_surface_inventory();
858 let inventory_keys = inventory_surface_keys();
859 let runtime_keys = runtime_registry_surface_keys();
860
861 assert_eq!(
862 inventory.len(),
863 inventory_keys.len(),
864 "inventory must not contain duplicate family/name/arity tuples"
865 );
866 assert_eq!(
867 inventory_keys, runtime_keys,
868 "inventory must exactly match the live registration path"
869 );
870 assert!(
871 inventory.windows(2).all(|entries| {
872 (
873 entries[0].family,
874 entries[0].class,
875 &entries[0].name,
876 entries[0].num_args,
877 ) <= (
878 entries[1].family,
879 entries[1].class,
880 &entries[1].name,
881 entries[1].num_args,
882 )
883 }),
884 "inventory must stay deterministically sorted"
885 );
886 }
887
888 #[test]
889 fn test_builtin_function_surface_inventory_classifies_representative_entries() {
890 let abs = find_surface_entry(BuiltinFunctionFamily::Scalar, "abs", 1);
891 assert_eq!(abs.class, BuiltinFunctionClass::CoreScalar);
892 assert!(!abs.is_alias);
893 assert_eq!(abs.surface_id, CORE_FUNCTION_SURFACE_ID);
894
895 let date = find_surface_entry(BuiltinFunctionFamily::Scalar, "date", -1);
896 assert_eq!(date.class, BuiltinFunctionClass::DateTimeScalar);
897 assert!(!date.is_alias);
898 assert_eq!(date.surface_id, CORE_FUNCTION_SURFACE_ID);
899
900 let power = find_surface_entry(BuiltinFunctionFamily::Scalar, "power", 2);
901 assert_eq!(power.class, BuiltinFunctionClass::MathScalar);
902 assert!(power.is_alias);
903 assert_eq!(power.surface_id, CORE_FUNCTION_SURFACE_ID);
904
905 let count = find_surface_entry(BuiltinFunctionFamily::Aggregate, "count", 0);
906 assert_eq!(count.class, BuiltinFunctionClass::Aggregate);
907 assert!(!count.is_alias);
908 assert_eq!(count.surface_id, CORE_FUNCTION_SURFACE_ID);
909
910 let row_number = find_surface_entry(BuiltinFunctionFamily::Window, "row_number", 0);
911 assert_eq!(row_number.class, BuiltinFunctionClass::Window);
912 assert!(!row_number.is_alias);
913 assert_eq!(row_number.surface_id, WINDOW_FUNCTION_SURFACE_ID);
914
915 let string_agg_window = find_surface_entry(BuiltinFunctionFamily::Window, "string_agg", 2);
916 assert_eq!(string_agg_window.class, BuiltinFunctionClass::Window);
917 assert!(string_agg_window.is_alias);
918 assert_eq!(string_agg_window.surface_id, WINDOW_FUNCTION_SURFACE_ID);
919 }
920
921 struct Double;
924
925 impl ScalarFunction for Double {
926 fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
927 Ok(SqliteValue::Integer(args[0].to_integer() * 2))
928 }
929
930 fn num_args(&self) -> i32 {
931 1
932 }
933
934 fn name(&self) -> &str {
935 "double"
936 }
937 }
938
939 struct VariadicConcat;
942
943 impl ScalarFunction for VariadicConcat {
944 fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
945 let mut out = String::new();
946 for a in args {
947 out.push_str(&a.to_text());
948 }
949 Ok(SqliteValue::Text(out.into()))
950 }
951
952 fn num_args(&self) -> i32 {
953 -1
954 }
955
956 fn min_args(&self) -> i32 {
957 1
958 }
959
960 fn max_args(&self) -> Option<i32> {
961 Some(3)
962 }
963
964 fn name(&self) -> &str {
965 "my_func"
966 }
967 }
968
969 struct TwoArgFunc;
972
973 impl ScalarFunction for TwoArgFunc {
974 fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
975 Ok(SqliteValue::Integer(
976 args[0].to_integer() + args[1].to_integer(),
977 ))
978 }
979
980 fn num_args(&self) -> i32 {
981 2
982 }
983
984 fn name(&self) -> &str {
985 "my_func"
986 }
987 }
988
989 fn assert_wrong_arg_count(
990 function: &dyn ScalarFunction,
991 args: &[SqliteValue],
992 expected_name: &str,
993 ) {
994 let err = function.invoke(args).expect_err("wrong arity should fail");
995 let expected = format!("wrong number of arguments to function {expected_name}()");
996 assert!(
997 matches!(&err, FrankenError::FunctionError(message) if message == &expected),
998 "expected {expected:?}, got {err:?}"
999 );
1000 }
1001
1002 fn assert_wrong_arg_count_aggregate(
1003 function: &ErasedAggregateFunction,
1004 args: &[SqliteValue],
1005 expected_name: &str,
1006 ) {
1007 let mut state = function.initial_state();
1008 let err = function
1009 .step(&mut state, args)
1010 .expect_err("wrong aggregate arity should fail");
1011 let expected = format!("wrong number of arguments to function {expected_name}()");
1012 assert!(
1013 matches!(&err, FrankenError::FunctionError(message) if message == &expected),
1014 "expected {expected:?}, got {err:?}"
1015 );
1016 }
1017
1018 fn assert_wrong_arg_count_window(
1019 function: &ErasedWindowFunction,
1020 args: &[SqliteValue],
1021 expected_name: &str,
1022 ) {
1023 let mut state = function.initial_state();
1024 let err = function
1025 .step(&mut state, args)
1026 .expect_err("wrong window arity should fail");
1027 let expected = format!("wrong number of arguments to function {expected_name}()");
1028 assert!(
1029 matches!(&err, FrankenError::FunctionError(message) if message == &expected),
1030 "expected {expected:?}, got {err:?}"
1031 );
1032 }
1033
1034 struct Product;
1035
1036 impl AggregateFunction for Product {
1037 type State = i64;
1038
1039 fn initial_state(&self) -> Self::State {
1040 1
1041 }
1042
1043 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> fsqlite_error::Result<()> {
1044 *state *= args[0].to_integer();
1045 Ok(())
1046 }
1047
1048 fn finalize(&self, state: Self::State) -> fsqlite_error::Result<SqliteValue> {
1049 Ok(SqliteValue::Integer(state))
1050 }
1051
1052 fn num_args(&self) -> i32 {
1053 1
1054 }
1055
1056 fn name(&self) -> &str {
1057 "product"
1058 }
1059 }
1060
1061 struct MovingSum;
1062
1063 impl WindowFunction for MovingSum {
1064 type State = i64;
1065
1066 fn initial_state(&self) -> Self::State {
1067 0
1068 }
1069
1070 fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> fsqlite_error::Result<()> {
1071 *state += args[0].to_integer();
1072 Ok(())
1073 }
1074
1075 fn inverse(
1076 &self,
1077 state: &mut Self::State,
1078 args: &[SqliteValue],
1079 ) -> fsqlite_error::Result<()> {
1080 *state -= args[0].to_integer();
1081 Ok(())
1082 }
1083
1084 fn value(&self, state: &Self::State) -> fsqlite_error::Result<SqliteValue> {
1085 Ok(SqliteValue::Integer(*state))
1086 }
1087
1088 fn finalize(&self, state: Self::State) -> fsqlite_error::Result<SqliteValue> {
1089 Ok(SqliteValue::Integer(state))
1090 }
1091
1092 fn num_args(&self) -> i32 {
1093 1
1094 }
1095
1096 fn name(&self) -> &str {
1097 "moving_sum"
1098 }
1099 }
1100
1101 #[test]
1102 fn test_registry_register_scalar() {
1103 let mut registry = FunctionRegistry::new();
1104 let previous = registry.register_scalar(Double);
1105 assert!(previous.is_none());
1106 assert!(registry.contains_scalar("double"));
1107 assert!(registry.contains_scalar("DOUBLE"));
1108 let f = registry
1109 .find_scalar(" Double ", 1)
1110 .expect("double registered");
1111 assert_eq!(
1112 f.invoke(&[SqliteValue::Integer(21)])
1113 .expect("invoke succeeds"),
1114 SqliteValue::Integer(42)
1115 );
1116 }
1117
1118 #[test]
1119 fn test_registry_case_insensitive_lookup() {
1120 let mut registry = FunctionRegistry::new();
1121 registry.register_scalar(Double);
1122
1123 assert!(registry.find_scalar("DOUBLE", 1).is_some());
1125 assert!(registry.find_scalar("Double", 1).is_some());
1126 assert!(registry.find_scalar(" double ", 1).is_some());
1127 }
1128
1129 #[test]
1130 fn test_registry_overwrite() {
1131 let mut registry = FunctionRegistry::new();
1132
1133 let prev = registry.register_scalar(Double);
1135 assert!(prev.is_none());
1136
1137 let prev = registry.register_scalar(Double);
1139 assert!(prev.is_some());
1140
1141 let f = registry.find_scalar("double", 1).unwrap();
1143 assert_eq!(
1144 f.invoke(&[SqliteValue::Integer(5)]).unwrap(),
1145 SqliteValue::Integer(10)
1146 );
1147 }
1148
1149 #[test]
1150 fn test_registry_variadic_fallback() {
1151 let mut registry = FunctionRegistry::new();
1152
1153 registry.register_scalar(VariadicConcat);
1155
1156 let too_few = registry
1157 .find_scalar("my_func", 0)
1158 .expect("known function with bad arity returns erroring scalar");
1159 assert_wrong_arg_count(too_few.as_ref(), &[], "my_func");
1160
1161 let f = registry
1163 .find_scalar("my_func", 3)
1164 .expect("variadic fallback");
1165 assert_eq!(
1166 f.invoke(&[
1167 SqliteValue::Text("a".into()),
1168 SqliteValue::Text("b".into()),
1169 SqliteValue::Text("c".into()),
1170 ])
1171 .unwrap(),
1172 SqliteValue::Text("abc".into())
1173 );
1174 let too_many = registry
1175 .find_scalar("my_func", 4)
1176 .expect("known function with bad arity returns erroring scalar");
1177 assert_wrong_arg_count(
1178 too_many.as_ref(),
1179 &[
1180 SqliteValue::Null,
1181 SqliteValue::Null,
1182 SqliteValue::Null,
1183 SqliteValue::Null,
1184 ],
1185 "my_func",
1186 );
1187 }
1188
1189 #[test]
1190 fn test_registry_exact_wrong_arity_returns_function_error() {
1191 let mut registry = FunctionRegistry::new();
1192 registry.register_scalar(Double);
1193
1194 let f = registry
1195 .find_scalar("double", 2)
1196 .expect("known function with wrong arity returns erroring scalar");
1197 assert_wrong_arg_count(
1198 f.as_ref(),
1199 &[SqliteValue::Integer(1), SqliteValue::Integer(2)],
1200 "double",
1201 );
1202 }
1203
1204 #[test]
1205 fn test_registry_exact_match_over_variadic() {
1206 let mut registry = FunctionRegistry::new();
1207
1208 registry.register_scalar(VariadicConcat);
1210 registry.register_scalar(TwoArgFunc);
1211
1212 let f = registry
1214 .find_scalar("my_func", 2)
1215 .expect("exact match found");
1216 assert_eq!(
1217 f.invoke(&[SqliteValue::Integer(10), SqliteValue::Integer(32)])
1218 .unwrap(),
1219 SqliteValue::Integer(42)
1220 );
1221
1222 let f = registry
1224 .find_scalar("my_func", 3)
1225 .expect("variadic fallback");
1226 assert_eq!(f.num_args(), -1);
1227 }
1228
1229 #[test]
1230 fn test_registry_not_found_returns_none() {
1231 let registry = FunctionRegistry::new();
1232 assert!(registry.find_scalar("nonexistent", 1).is_none());
1233 assert!(registry.find_aggregate("nonexistent", 1).is_none());
1234 assert!(registry.find_window("nonexistent", 1).is_none());
1235 }
1236
1237 #[test]
1238 fn test_registry_register_and_resolve_aggregate() {
1239 let mut registry = FunctionRegistry::new();
1240 let previous = registry.register_aggregate(Product);
1241 assert!(previous.is_none());
1242 assert!(registry.contains_aggregate("product"));
1243 let f = registry
1244 .find_aggregate("PRODUCT", 1)
1245 .expect("product aggregate registered");
1246
1247 let mut state = f.initial_state();
1248 f.step(&mut state, &[SqliteValue::Integer(2)])
1249 .expect("step 1");
1250 f.step(&mut state, &[SqliteValue::Integer(3)])
1251 .expect("step 2");
1252 f.step(&mut state, &[SqliteValue::Integer(7)])
1253 .expect("step 3");
1254
1255 assert_eq!(
1256 f.finalize(state).expect("finalize succeeds"),
1257 SqliteValue::Integer(42)
1258 );
1259 }
1260
1261 #[test]
1262 fn test_registry_aggregate_type_erased() {
1263 let mut registry = FunctionRegistry::new();
1264 registry.register_aggregate(Product);
1265
1266 let f = registry
1268 .find_aggregate("product", 1)
1269 .expect("product found");
1270 let mut state = f.initial_state();
1271 f.step(&mut state, &[SqliteValue::Integer(6)]).unwrap();
1272 f.step(&mut state, &[SqliteValue::Integer(7)]).unwrap();
1273 assert_eq!(f.finalize(state).unwrap(), SqliteValue::Integer(42));
1274 assert_eq!(f.name(), "product");
1275 }
1276
#[test]
fn test_registry_aggregate_wrong_arity_returns_function_error() {
    let mut registry = FunctionRegistry::new();
    registry.register_aggregate(Product);

    // A known name asked for with an unsupported arg count still resolves.
    // The result is a stand-in that `assert_wrong_arg_count_aggregate`
    // expects to fail when used.
    let stand_in = registry
        .find_aggregate("product", 0)
        .expect("known aggregate with wrong arity returns erroring aggregate");
    assert_wrong_arg_count_aggregate(stand_in.as_ref(), &[], "product");
}
1287
#[test]
fn test_registry_register_and_resolve_window() {
    let mut registry = FunctionRegistry::new();

    // The first registration displaces nothing and makes the name visible.
    assert!(registry.register_window(MovingSum).is_none());
    assert!(registry.contains_window("moving_sum"));

    let win = registry
        .find_window("MOVING_SUM", 1)
        .expect("moving_sum window registered");

    let mut state = win.initial_state();
    for (n, label) in [(10, "step 1"), (20, "step 2"), (30, "step 3")] {
        win.step(&mut state, &[SqliteValue::Integer(n)]).expect(label);
    }
    // 10 + 20 + 30
    assert_eq!(win.value(&state).expect("value"), SqliteValue::Integer(60));

    // Slide the frame: remove 10 and add 40, giving 20 + 30 + 40.
    win.inverse(&mut state, &[SqliteValue::Integer(10)])
        .expect("inverse 1");
    win.step(&mut state, &[SqliteValue::Integer(40)])
        .expect("step 4");
    assert_eq!(win.value(&state).expect("value"), SqliteValue::Integer(90));
}
1313
#[test]
fn test_registry_window_wrong_arity_returns_function_error() {
    let mut registry = FunctionRegistry::new();
    registry.register_window(MovingSum);

    // A known window name with a bad arity resolves to a stand-in that
    // `assert_wrong_arg_count_window` expects to fail when used.
    let stand_in = registry
        .find_window("moving_sum", 0)
        .expect("known window with wrong arity returns erroring window");
    assert_wrong_arg_count_window(stand_in.as_ref(), &[], "moving_sum");
}
1324
#[test]
fn test_registry_window_accepts_arg_count_reports_known_bad_arity() {
    let mut registry = FunctionRegistry::new();
    registry.register_window(MovingSum);

    // Expected results:
    // - Some(true): known name, accepted arity.
    // - Some(false): known name, rejected arity.
    // - None: name not registered.
    let cases = [
        ("moving_sum", 1, Some(true)),
        ("moving_sum", 0, Some(false)),
        ("missing_window", 1, None),
    ];
    for (name, argc, expected) in cases {
        assert_eq!(registry.window_accepts_arg_count(name, argc), expected);
    }
}
1340
#[test]
fn test_registry_window_type_erased() {
    let mut registry = FunctionRegistry::new();
    registry.register_window(MovingSum);

    let erased = registry
        .find_window("moving_sum", 1)
        .expect("moving_sum found");
    let mut frame = erased.initial_state();

    // Adding and then removing the same row brings the running value back to zero.
    erased.step(&mut frame, &[SqliteValue::Integer(100)]).unwrap();
    assert_eq!(erased.value(&frame).unwrap(), SqliteValue::Integer(100));
    erased.inverse(&mut frame, &[SqliteValue::Integer(100)]).unwrap();
    assert_eq!(erased.value(&frame).unwrap(), SqliteValue::Integer(0));

    // `finalize` consumes the state and reports the remaining sum.
    erased.step(&mut frame, &[SqliteValue::Integer(42)]).unwrap();
    assert_eq!(erased.finalize(frame).unwrap(), SqliteValue::Integer(42));
}
1361
#[test]
fn test_function_key_equality() {
    let upper = FunctionKey::new("ABS", 1);
    let lower = FunctionKey::new("abs", 1);
    let other_arity = FunctionKey::new("ABS", 2);

    // Name comparison ignores case, but the argument count is part of the key.
    assert_eq!(upper, lower, "case-insensitive equality");
    assert_ne!(upper, other_arity, "different num_args");
}
1371
#[test]
fn test_e2e_custom_collation_in_order_by() {
    use collation::{BinaryCollation, CollationFunction, NoCaseCollation, RtrimCollation};

    /// User-defined collation that orders byte strings in descending byte order.
    struct ReverseAlpha;

    impl CollationFunction for ReverseAlpha {
        fn name(&self) -> &str {
            "REVERSE_ALPHA"
        }

        fn compare(&self, left: &[u8], right: &[u8]) -> std::cmp::Ordering {
            // Swapping the operands of the byte-wise comparison inverts the order.
            right.cmp(left)
        }
    }

    let coll = ReverseAlpha;
    let mut data: Vec<&[u8]> = vec![b"banana", b"apple", b"cherry", b"date"];
    data.sort_by(|a, b| coll.compare(a, b));

    let expected: Vec<&[u8]> = vec![b"date", b"cherry", b"banana", b"apple"];
    assert_eq!(data, expected);
    assert_eq!(coll.name(), "REVERSE_ALPHA");

    // Built-in and user-defined collations can live side by side behind the
    // trait object.
    let collations: Vec<Box<dyn CollationFunction>> = vec![
        Box::new(BinaryCollation),
        Box::new(NoCaseCollation),
        Box::new(RtrimCollation),
        Box::new(ReverseAlpha),
    ];
    assert_eq!(collations.len(), 4);

    // Re-sorting the reversed data with BINARY must restore the complete
    // ascending order, not merely move "apple" to the front.
    let mut binary_sorted = data.clone();
    binary_sorted.sort_by(|a, b| collations[0].compare(a, b));
    let ascending: Vec<&[u8]> = vec![b"apple", b"banana", b"cherry", b"date"];
    assert_eq!(binary_sorted, ascending);

    // Dispatching through the boxed custom collation must match the direct call.
    let mut boxed_sorted = binary_sorted.clone();
    boxed_sorted.sort_by(|a, b| collations[3].compare(a, b));
    assert_eq!(boxed_sorted, expected);
}
1415
#[test]
fn test_e2e_authorizer_sandboxing() {
    use authorizer::{AuthAction, AuthResult, Authorizer};

    /// Allows reads and selects, hides the "secret" column, and rejects writes.
    struct SelectOnlyAuthorizer;

    impl Authorizer for SelectOnlyAuthorizer {
        fn authorize(
            &self,
            action: AuthAction,
            _arg1: Option<&str>,
            arg2: Option<&str>,
            _db_name: Option<&str>,
            _trigger: Option<&str>,
        ) -> AuthResult {
            match action {
                // Reading the protected column is silently nulled out, not rejected.
                AuthAction::Read if arg2 == Some("secret") => AuthResult::Ignore,
                AuthAction::Insert | AuthAction::Update | AuthAction::Delete => {
                    AuthResult::Deny
                }
                // Select, every other Read, and any other action are permitted.
                _ => AuthResult::Ok,
            }
        }
    }

    let auth = SelectOnlyAuthorizer;
    // Every probe targets the "main" schema with no enclosing trigger.
    let ask = |action: AuthAction, table: Option<&str>, column: Option<&str>| {
        auth.authorize(action, table, column, Some("main"), None)
    };

    assert_eq!(
        ask(AuthAction::Select, None, None),
        AuthResult::Ok,
        "SELECT must be allowed"
    );
    assert_eq!(
        ask(AuthAction::Insert, Some("users"), None),
        AuthResult::Deny,
        "INSERT must be denied (compile-time auth error)"
    );
    assert_eq!(
        ask(AuthAction::Update, Some("users"), Some("email")),
        AuthResult::Deny,
    );
    assert_eq!(
        ask(AuthAction::Delete, Some("users"), None),
        AuthResult::Deny,
    );
    assert_eq!(
        ask(AuthAction::Read, Some("users"), Some("secret")),
        AuthResult::Ignore,
        "Ignore must nullify column"
    );
    assert_eq!(
        ask(AuthAction::Read, Some("users"), Some("name")),
        AuthResult::Ok,
    );
}
1507
#[test]
fn test_e2e_function_registry_resolution() {
    /// Fixed-arity `abs(x)`.
    struct Abs1;

    impl ScalarFunction for Abs1 {
        fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
            let magnitude = args[0].to_integer().abs();
            Ok(SqliteValue::Integer(magnitude))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "abs"
        }
    }

    /// Variadic `abs(...)` that returns the sum of the absolute values of its arguments.
    struct AbsVariadic;

    impl ScalarFunction for AbsVariadic {
        fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
            let total = args
                .iter()
                .fold(0_i64, |acc, arg| acc + arg.to_integer().abs());
            Ok(SqliteValue::Integer(total))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "abs"
        }
    }

    let mut registry = FunctionRegistry::new();
    registry.register_scalar(Abs1);
    registry.register_scalar(AbsVariadic);

    // With one argument, the exact-arity overload beats the variadic one.
    let exact = registry.find_scalar("abs", 1).expect("abs(1) found");
    assert_eq!(exact.num_args(), 1, "exact 1-arg match");
    assert_eq!(
        exact.invoke(&[SqliteValue::Integer(-5)]).unwrap(),
        SqliteValue::Integer(5)
    );

    // No 2-arg overload is registered, so lookup falls back to the variadic one.
    let fallback = registry.find_scalar("abs", 2).expect("abs variadic found");
    assert_eq!(fallback.num_args(), -1, "variadic fallback for 2 args");
    let pair = [SqliteValue::Integer(-5), SqliteValue::Integer(-3)];
    assert_eq!(fallback.invoke(&pair).unwrap(), SqliteValue::Integer(8));

    assert!(registry.find_scalar("nonexistent", 1).is_none());
}
1569
#[test]
fn test_authorizer_called_at_compile_time() {
    use authorizer::{AuthAction, AuthResult, Authorizer};
    use std::sync::Mutex;

    /// Authorizer that records every action it is asked about and allows them all.
    #[derive(Default)]
    struct TrackingAuthorizer {
        calls: Mutex<Vec<AuthAction>>,
    }

    impl Authorizer for TrackingAuthorizer {
        fn authorize(
            &self,
            action: AuthAction,
            _arg1: Option<&str>,
            _arg2: Option<&str>,
            _db_name: Option<&str>,
            _trigger: Option<&str>,
        ) -> AuthResult {
            self.calls.lock().unwrap().push(action);
            AuthResult::Ok
        }
    }

    let auth = TrackingAuthorizer::default();

    // Simulate the callbacks a prepare would issue: one Select, then one Read
    // per referenced column of "users".
    // NOTE(review): this invokes the authorizer directly. No statement is
    // actually prepared or stepped, so the post-execution check below only
    // re-counts the log.
    auth.authorize(AuthAction::Select, None, None, Some("main"), None);
    for column in ["name", "email", "id"] {
        auth.authorize(
            AuthAction::Read,
            Some("users"),
            Some(column),
            Some("main"),
            None,
        );
    }

    {
        let calls = auth.calls.lock().unwrap();
        assert_eq!(calls.len(), 4, "authorizer called 4 times during prepare");
        assert_eq!(calls[0], AuthAction::Select);
        for call in &calls[1..] {
            assert_eq!(*call, AuthAction::Read);
        }
    }

    let recorded_after = auth.calls.lock().unwrap().len();
    assert_eq!(
        recorded_after,
        4,
        "authorizer must NOT be called during step/execution"
    );
}
1657}