// Crate-wide opt-out of clippy's `unnecessary_literal_bound` lint. That lint
// suggests `&'static str` for functions declared to return `&str` whose body is
// a string literal (e.g. the `fn name(&self) -> &str { "double" }`
// implementations in this file's tests).
// NOTE(review): presumably kept because these signatures are dictated by the
// trait definitions (`ScalarFunction::name`, etc.) — confirm.
#![allow(clippy::unnecessary_literal_bound)]
13
14use std::any::Any;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU64, Ordering};
18
19use tracing::debug;
20
/// Number of function invocations reported through [`record_func_call`].
static FSQLITE_FUNC_CALLS_TOTAL: AtomicU64 = AtomicU64::new(0);
/// Sum of the durations (microseconds) reported through [`record_func_call`].
static FSQLITE_FUNC_EVAL_DURATION_US_TOTAL: AtomicU64 = AtomicU64::new(0);

/// Copy of the function-evaluation counters taken by [`func_metrics_snapshot`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FuncMetricsSnapshot {
    /// Invocations recorded since the last [`reset_func_metrics`].
    pub calls_total: u64,
    /// Accumulated evaluation time, in microseconds, since the last reset.
    pub eval_duration_us_total: u64,
}

/// Reads the current function-evaluation counters.
///
/// Each counter is read atomically, but the two reads are independent
/// `Relaxed` loads. Under concurrent recording the fields are therefore not
/// guaranteed to describe exactly the same set of calls.
#[must_use]
pub fn func_metrics_snapshot() -> FuncMetricsSnapshot {
    let calls_total = FSQLITE_FUNC_CALLS_TOTAL.load(Ordering::Relaxed);
    let eval_duration_us_total = FSQLITE_FUNC_EVAL_DURATION_US_TOTAL.load(Ordering::Relaxed);
    FuncMetricsSnapshot {
        calls_total,
        eval_duration_us_total,
    }
}

/// Zeroes both function-evaluation counters.
pub fn reset_func_metrics() {
    for counter in [
        &FSQLITE_FUNC_CALLS_TOTAL,
        &FSQLITE_FUNC_EVAL_DURATION_US_TOTAL,
    ] {
        counter.store(0, Ordering::Relaxed);
    }
}

/// Records one function invocation that took `duration_us` microseconds.
///
/// Bumps the call counter by one and adds `duration_us` to the duration total.
pub fn record_func_call(duration_us: u64) {
    let calls = &FSQLITE_FUNC_CALLS_TOTAL;
    let elapsed = &FSQLITE_FUNC_EVAL_DURATION_US_TOTAL;
    calls.fetch_add(1, Ordering::Relaxed);
    elapsed.fetch_add(duration_us, Ordering::Relaxed);
}
57
/// Count of user-defined function registrations reported through
/// [`record_udf_registered`].
static FSQLITE_UDF_REGISTERED: AtomicU64 = AtomicU64::new(0);

/// Notes that one more user-defined function has been registered.
pub fn record_udf_registered() {
    let registered = &FSQLITE_UDF_REGISTERED;
    registered.fetch_add(1, Ordering::Relaxed);
}

/// Returns how many UDF registrations have been recorded since the last
/// [`reset_udf_metrics`].
#[must_use]
pub fn udf_registered_count() -> u64 {
    let registered = &FSQLITE_UDF_REGISTERED;
    registered.load(Ordering::Relaxed)
}

/// Zeroes the UDF registration counter.
pub fn reset_udf_metrics() {
    let registered = &FSQLITE_UDF_REGISTERED;
    registered.store(0, Ordering::Relaxed);
}
78
79pub mod agg_builtins;
80pub mod aggregate;
81pub mod authorizer;
82pub mod builtins;
83pub mod collation;
84pub mod datetime;
85pub mod math;
86pub mod scalar;
87pub mod vtab;
88pub mod window;
89pub mod window_builtins;
90
91pub use agg_builtins::register_aggregate_builtins;
92pub use aggregate::{AggregateAdapter, AggregateFunction};
93pub use authorizer::{AuthAction, AuthResult, Authorizer, AuthorizerAction, AuthorizerDecision};
94pub use builtins::{
95 get_last_changes, get_last_insert_rowid, register_builtins, set_last_changes,
96 set_last_insert_rowid,
97};
98pub use collation::{
99 BinaryCollation, CollationAnnotation, CollationFunction, CollationRegistry, CollationSource,
100 NoCaseCollation, RtrimCollation, resolve_collation,
101};
102pub use datetime::register_datetime_builtins;
103pub use math::register_math_builtins;
104pub use scalar::ScalarFunction;
105pub use vtab::{
106 ColumnContext, ConstraintOp, IndexConstraint, IndexConstraintUsage, IndexInfo, IndexOrderBy,
107 VirtualTable, VirtualTableCursor,
108};
109pub use window::{WindowAdapter, WindowFunction};
110pub use window_builtins::register_window_builtins;
111
112pub type ErasedAggregateFunction = dyn AggregateFunction<State = Box<dyn Any + Send>>;
114
115pub type ErasedWindowFunction = dyn WindowFunction<State = Box<dyn Any + Send>>;
117
118#[derive(Debug, Clone, Hash, Eq, PartialEq)]
123pub struct FunctionKey {
124 pub name: String,
126 pub num_args: i32,
128}
129
130impl FunctionKey {
131 #[must_use]
133 pub fn new(name: &str, num_args: i32) -> Self {
134 Self {
135 name: canonical_name(name),
136 num_args,
137 }
138 }
139}
140
141#[derive(Default)]
149pub struct FunctionRegistry {
150 scalars: HashMap<FunctionKey, Arc<dyn ScalarFunction>>,
151 aggregates: HashMap<FunctionKey, Arc<ErasedAggregateFunction>>,
152 windows: HashMap<FunctionKey, Arc<ErasedWindowFunction>>,
153}
154
155impl FunctionRegistry {
156 #[must_use]
158 pub fn new() -> Self {
159 Self::default()
160 }
161
162 #[must_use]
167 pub fn clone_from_arc(arc: &Arc<Self>) -> Self {
168 Self {
169 scalars: arc.scalars.clone(),
170 aggregates: arc.aggregates.clone(),
171 windows: arc.windows.clone(),
172 }
173 }
174
175 pub fn register_scalar<F>(&mut self, function: F) -> Option<Arc<dyn ScalarFunction>>
180 where
181 F: ScalarFunction + 'static,
182 {
183 let key = FunctionKey::new(function.name(), function.num_args());
184 self.scalars.insert(key, Arc::new(function))
185 }
186
187 pub fn register_aggregate<F>(&mut self, function: F) -> Option<Arc<ErasedAggregateFunction>>
191 where
192 F: AggregateFunction + 'static,
193 F::State: 'static,
194 {
195 let key = FunctionKey::new(function.name(), function.num_args());
196 self.aggregates
197 .insert(key, Arc::new(AggregateAdapter::new(function)))
198 }
199
200 pub fn register_window<F>(&mut self, function: F) -> Option<Arc<ErasedWindowFunction>>
204 where
205 F: WindowFunction + 'static,
206 F::State: 'static,
207 {
208 let key = FunctionKey::new(function.name(), function.num_args());
209 self.windows
210 .insert(key, Arc::new(WindowAdapter::new(function)))
211 }
212
213 #[must_use]
218 pub fn find_scalar(&self, name: &str, num_args: i32) -> Option<Arc<dyn ScalarFunction>> {
219 let canon = canonical_name(name);
220 let exact = FunctionKey {
221 name: canon.clone(),
222 num_args,
223 };
224 if let Some(f) = self.scalars.get(&exact) {
225 debug!(name = %canon, arity = num_args, kind = "scalar", hit = "exact", "registry lookup");
226 return Some(Arc::clone(f));
227 }
228 let variadic = FunctionKey {
230 name: canon.clone(),
231 num_args: -1,
232 };
233 let result = self.scalars.get(&variadic).map(Arc::clone);
234 debug!(
235 name = %canon,
236 arity = num_args,
237 kind = "scalar",
238 hit = if result.is_some() { "variadic" } else { "miss" },
239 "registry lookup"
240 );
241 result
242 }
243
244 #[must_use]
248 pub fn find_aggregate(
249 &self,
250 name: &str,
251 num_args: i32,
252 ) -> Option<Arc<ErasedAggregateFunction>> {
253 let canon = canonical_name(name);
254 let exact = FunctionKey {
255 name: canon.clone(),
256 num_args,
257 };
258 if let Some(f) = self.aggregates.get(&exact) {
259 debug!(name = %canon, arity = num_args, kind = "aggregate", hit = "exact", "registry lookup");
260 return Some(Arc::clone(f));
261 }
262 let variadic = FunctionKey {
263 name: canon.clone(),
264 num_args: -1,
265 };
266 let result = self.aggregates.get(&variadic).map(Arc::clone);
267 debug!(
268 name = %canon,
269 arity = num_args,
270 kind = "aggregate",
271 hit = if result.is_some() { "variadic" } else { "miss" },
272 "registry lookup"
273 );
274 result
275 }
276
277 #[must_use]
281 pub fn find_window(&self, name: &str, num_args: i32) -> Option<Arc<ErasedWindowFunction>> {
282 let canon = canonical_name(name);
283 let exact = FunctionKey {
284 name: canon.clone(),
285 num_args,
286 };
287 if let Some(f) = self.windows.get(&exact) {
288 debug!(name = %canon, arity = num_args, kind = "window", hit = "exact", "registry lookup");
289 return Some(Arc::clone(f));
290 }
291 let variadic = FunctionKey {
292 name: canon.clone(),
293 num_args: -1,
294 };
295 let result = self.windows.get(&variadic).map(Arc::clone);
296 debug!(
297 name = %canon,
298 arity = num_args,
299 kind = "window",
300 hit = if result.is_some() { "variadic" } else { "miss" },
301 "registry lookup"
302 );
303 result
304 }
305
306 #[must_use]
309 pub fn contains_scalar(&self, name: &str) -> bool {
310 let canon = canonical_name(name);
311 self.scalars.keys().any(|k| k.name == canon)
312 }
313
314 #[must_use]
317 pub fn contains_aggregate(&self, name: &str) -> bool {
318 let canon = canonical_name(name);
319 self.aggregates.keys().any(|k| k.name == canon)
320 }
321
322 #[must_use]
325 pub fn contains_window(&self, name: &str) -> bool {
326 let canon = canonical_name(name);
327 self.windows.keys().any(|k| k.name == canon)
328 }
329
330 #[must_use]
334 pub fn aggregate_names_lowercase(&self) -> Vec<String> {
335 let mut names: Vec<String> = self
336 .aggregates
337 .keys()
338 .map(|k| k.name.to_ascii_lowercase())
339 .collect();
340 names.sort();
341 names.dedup();
342 names
343 }
344}
345
/// Normalizes a function name for registry keys and lookups.
///
/// Strips leading and trailing whitespace, then upper-cases ASCII letters;
/// non-ASCII characters pass through unchanged.
fn canonical_name(name: &str) -> String {
    let name = name.trim();
    name.to_ascii_uppercase()
}
349
#[cfg(test)]
mod tests {
    use fsqlite_types::SqliteValue;

    use super::*;

    /// Scalar fixture: one-argument `double(x)` returning `x * 2`.
    struct Double;

    impl ScalarFunction for Double {
        fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(args[0].to_integer() * 2))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "double"
        }
    }

    /// Scalar fixture: variadic (`num_args == -1`) `my_func` that concatenates
    /// the text form of every argument.
    struct VariadicConcat;

    impl ScalarFunction for VariadicConcat {
        fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
            let mut out = String::new();
            for a in args {
                out.push_str(&a.to_text());
            }
            Ok(SqliteValue::Text(out))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "my_func"
        }
    }

    /// Scalar fixture: two-argument `my_func` overload returning the sum of its
    /// integer arguments. It shares its name with `VariadicConcat` to exercise
    /// exact-vs-variadic resolution.
    struct TwoArgFunc;

    impl ScalarFunction for TwoArgFunc {
        fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(
                args[0].to_integer() + args[1].to_integer(),
            ))
        }

        fn num_args(&self) -> i32 {
            2
        }

        fn name(&self) -> &str {
            "my_func"
        }
    }

    /// Aggregate fixture: product of its integer inputs, starting from 1.
    struct Product;

    impl AggregateFunction for Product {
        type State = i64;

        fn initial_state(&self) -> Self::State {
            1
        }

        fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> fsqlite_error::Result<()> {
            *state *= args[0].to_integer();
            Ok(())
        }

        fn finalize(&self, state: Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "product"
        }
    }

    /// Window fixture: running sum where `step` adds the integer argument and
    /// `inverse` subtracts it again.
    struct MovingSum;

    impl WindowFunction for MovingSum {
        type State = i64;

        fn initial_state(&self) -> Self::State {
            0
        }

        fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> fsqlite_error::Result<()> {
            *state += args[0].to_integer();
            Ok(())
        }

        fn inverse(
            &self,
            state: &mut Self::State,
            args: &[SqliteValue],
        ) -> fsqlite_error::Result<()> {
            *state -= args[0].to_integer();
            Ok(())
        }

        fn value(&self, state: &Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(*state))
        }

        fn finalize(&self, state: Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "moving_sum"
        }
    }

    #[test]
    fn test_registry_register_scalar() {
        let mut registry = FunctionRegistry::new();
        // First registration under this key has nothing to replace.
        let previous = registry.register_scalar(Double);
        assert!(previous.is_none());
        assert!(registry.contains_scalar("double"));
        assert!(registry.contains_scalar("DOUBLE"));
        // Lookup trims surrounding whitespace and ignores ASCII case.
        let f = registry
            .find_scalar(" Double ", 1)
            .expect("double registered");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(21)])
                .expect("invoke succeeds"),
            SqliteValue::Integer(42)
        );
    }

    #[test]
    fn test_registry_case_insensitive_lookup() {
        let mut registry = FunctionRegistry::new();
        registry.register_scalar(Double);

        // Registered as "double"; every casing/whitespace variant resolves.
        assert!(registry.find_scalar("DOUBLE", 1).is_some());
        assert!(registry.find_scalar("Double", 1).is_some());
        assert!(registry.find_scalar(" double ", 1).is_some());
    }

    #[test]
    fn test_registry_overwrite() {
        let mut registry = FunctionRegistry::new();

        // Initial registration: no previous entry.
        let prev = registry.register_scalar(Double);
        assert!(prev.is_none());

        // Re-registering the same (name, arity) hands back the replaced entry.
        let prev = registry.register_scalar(Double);
        assert!(prev.is_some());

        // The registry still resolves to a working implementation.
        let f = registry.find_scalar("double", 1).unwrap();
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(5)]).unwrap(),
            SqliteValue::Integer(10)
        );
    }

    #[test]
    fn test_registry_variadic_fallback() {
        let mut registry = FunctionRegistry::new();

        // Only a variadic `my_func` is registered.
        registry.register_scalar(VariadicConcat);

        // A 3-argument call has no exact entry and falls back to the variadic one.
        let f = registry
            .find_scalar("my_func", 3)
            .expect("variadic fallback");
        assert_eq!(
            f.invoke(&[
                SqliteValue::Text("a".to_owned()),
                SqliteValue::Text("b".to_owned()),
                SqliteValue::Text("c".to_owned()),
            ])
            .unwrap(),
            SqliteValue::Text("abc".to_owned())
        );
    }

    #[test]
    fn test_registry_exact_match_over_variadic() {
        let mut registry = FunctionRegistry::new();

        // Register both a variadic and a 2-argument `my_func`.
        registry.register_scalar(VariadicConcat);
        registry.register_scalar(TwoArgFunc);

        // With 2 arguments the exact-arity overload wins.
        let f = registry
            .find_scalar("my_func", 2)
            .expect("exact match found");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(10), SqliteValue::Integer(32)])
                .unwrap(),
            SqliteValue::Integer(42)
        );

        // Any other arity still reaches the variadic registration.
        let f = registry
            .find_scalar("my_func", 5)
            .expect("variadic fallback");
        assert_eq!(f.num_args(), -1);
    }

    #[test]
    fn test_registry_not_found_returns_none() {
        let registry = FunctionRegistry::new();
        assert!(registry.find_scalar("nonexistent", 1).is_none());
        assert!(registry.find_aggregate("nonexistent", 1).is_none());
        assert!(registry.find_window("nonexistent", 1).is_none());
    }

    #[test]
    fn test_registry_register_and_resolve_aggregate() {
        let mut registry = FunctionRegistry::new();
        let previous = registry.register_aggregate(Product);
        assert!(previous.is_none());
        assert!(registry.contains_aggregate("product"));
        let f = registry
            .find_aggregate("PRODUCT", 1)
            .expect("product aggregate registered");

        // Drive the type-erased aggregate through a full step/finalize cycle.
        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(2)])
            .expect("step 1");
        f.step(&mut state, &[SqliteValue::Integer(3)])
            .expect("step 2");
        f.step(&mut state, &[SqliteValue::Integer(7)])
            .expect("step 3");

        // 2 * 3 * 7 == 42
        assert_eq!(
            f.finalize(state).expect("finalize succeeds"),
            SqliteValue::Integer(42)
        );
    }

    #[test]
    fn test_registry_aggregate_type_erased() {
        let mut registry = FunctionRegistry::new();
        registry.register_aggregate(Product);

        // The stored aggregate is the erased adapter, yet it behaves like `Product`.
        let f = registry
            .find_aggregate("product", 1)
            .expect("product found");
        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(6)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(7)]).unwrap();
        assert_eq!(f.finalize(state).unwrap(), SqliteValue::Integer(42));
        assert_eq!(f.name(), "product");
    }

    #[test]
    fn test_registry_register_and_resolve_window() {
        let mut registry = FunctionRegistry::new();
        let previous = registry.register_window(MovingSum);
        assert!(previous.is_none());
        assert!(registry.contains_window("moving_sum"));
        let f = registry
            .find_window("MOVING_SUM", 1)
            .expect("moving_sum window registered");

        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(10)])
            .expect("step 1");
        f.step(&mut state, &[SqliteValue::Integer(20)])
            .expect("step 2");
        f.step(&mut state, &[SqliteValue::Integer(30)])
            .expect("step 3");
        assert_eq!(f.value(&state).expect("value"), SqliteValue::Integer(60));

        // Slide the frame: drop 10, add 40 -> 20 + 30 + 40 == 90.
        f.inverse(&mut state, &[SqliteValue::Integer(10)])
            .expect("inverse 1");
        f.step(&mut state, &[SqliteValue::Integer(40)])
            .expect("step 4");
        assert_eq!(f.value(&state).expect("value"), SqliteValue::Integer(90));
    }

    #[test]
    fn test_registry_window_type_erased() {
        let mut registry = FunctionRegistry::new();
        registry.register_window(MovingSum);

        let f = registry
            .find_window("moving_sum", 1)
            .expect("moving_sum found");

        // Exercise step, inverse, value and finalize through the erased adapter.
        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(100)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(100));

        f.inverse(&mut state, &[SqliteValue::Integer(100)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(0));

        f.step(&mut state, &[SqliteValue::Integer(42)]).unwrap();
        assert_eq!(f.finalize(state).unwrap(), SqliteValue::Integer(42));
    }

    #[test]
    fn test_function_key_equality() {
        let k1 = FunctionKey::new("ABS", 1);
        let k2 = FunctionKey::new("abs", 1);
        let k3 = FunctionKey::new("ABS", 2);

        assert_eq!(k1, k2, "case-insensitive equality");
        assert_ne!(k1, k3, "different num_args");
    }

    #[test]
    fn test_e2e_custom_collation_in_order_by() {
        use collation::{BinaryCollation, CollationFunction, NoCaseCollation, RtrimCollation};

        // User-defined collation ordering byte strings in descending order.
        struct ReverseAlpha;

        impl CollationFunction for ReverseAlpha {
            fn name(&self) -> &str {
                "REVERSE_ALPHA"
            }

            fn compare(&self, left: &[u8], right: &[u8]) -> std::cmp::Ordering {
                // Swapping the operands reverses the byte-wise comparison.
                right.cmp(left)
            }
        }

        let coll = ReverseAlpha;
        let mut data: Vec<&[u8]> = vec![b"banana", b"apple", b"cherry", b"date"];
        data.sort_by(|a, b| coll.compare(a, b));

        let expected: Vec<&[u8]> = vec![b"date", b"cherry", b"banana", b"apple"];
        assert_eq!(data, expected);
        assert_eq!(coll.name(), "REVERSE_ALPHA");

        // Built-in and custom collations are usable through the same trait object.
        let collations: Vec<Box<dyn CollationFunction>> = vec![
            Box::new(BinaryCollation),
            Box::new(NoCaseCollation),
            Box::new(RtrimCollation),
            Box::new(ReverseAlpha),
        ];
        assert_eq!(collations.len(), 4);

        // Re-sorting the reversed data with the first entry (BINARY) puts "apple" first.
        let mut binary_sorted = data.clone();
        binary_sorted.sort_by(|a, b| collations[0].compare(a, b));
        assert_eq!(binary_sorted[0], b"apple");
    }

    #[test]
    fn test_e2e_authorizer_sandboxing() {
        use authorizer::{AuthAction, AuthResult, Authorizer};

        // Authorizer that allows reads/selects, rejects data modification, and
        // answers `Ignore` for reads of a column named "secret".
        struct SelectOnlyAuthorizer;

        impl Authorizer for SelectOnlyAuthorizer {
            fn authorize(
                &self,
                action: AuthAction,
                _arg1: Option<&str>,
                arg2: Option<&str>,
                _db_name: Option<&str>,
                _trigger: Option<&str>,
            ) -> AuthResult {
                match action {
                    AuthAction::Select | AuthAction::Read => {
                        // Only the Read action is column-specific here.
                        if action == AuthAction::Read && arg2 == Some("secret") {
                            return AuthResult::Ignore;
                        }
                        AuthResult::Ok
                    }
                    AuthAction::Insert | AuthAction::Update | AuthAction::Delete => {
                        AuthResult::Deny
                    }
                    _ => AuthResult::Ok,
                }
            }
        }

        let auth = SelectOnlyAuthorizer;

        // SELECT is permitted.
        assert_eq!(
            auth.authorize(AuthAction::Select, None, None, Some("main"), None),
            AuthResult::Ok,
            "SELECT must be allowed"
        );

        // INSERT is rejected.
        assert_eq!(
            auth.authorize(AuthAction::Insert, Some("users"), None, Some("main"), None),
            AuthResult::Deny,
            "INSERT must be denied (compile-time auth error)"
        );

        // UPDATE is rejected.
        assert_eq!(
            auth.authorize(
                AuthAction::Update,
                Some("users"),
                Some("email"),
                Some("main"),
                None
            ),
            AuthResult::Deny,
        );

        // DELETE is rejected.
        assert_eq!(
            auth.authorize(AuthAction::Delete, Some("users"), None, Some("main"), None),
            AuthResult::Deny,
        );

        // Reading the "secret" column yields Ignore rather than Ok or Deny.
        assert_eq!(
            auth.authorize(
                AuthAction::Read,
                Some("users"),
                Some("secret"),
                Some("main"),
                None
            ),
            AuthResult::Ignore,
            "Ignore must nullify column"
        );

        // Reading any other column is allowed.
        assert_eq!(
            auth.authorize(
                AuthAction::Read,
                Some("users"),
                Some("name"),
                Some("main"),
                None
            ),
            AuthResult::Ok,
        );
    }

    #[test]
    fn test_e2e_function_registry_resolution() {
        // One-argument `abs` returning the absolute value.
        struct Abs1;

        impl ScalarFunction for Abs1 {
            fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
                Ok(SqliteValue::Integer(args[0].to_integer().abs()))
            }

            fn num_args(&self) -> i32 {
                1
            }

            fn name(&self) -> &str {
                "abs"
            }
        }

        // Variadic `abs` returning the sum of absolute values, so its result is
        // distinguishable from `Abs1`.
        struct AbsVariadic;

        impl ScalarFunction for AbsVariadic {
            fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
                let sum: i64 = args.iter().map(|a| a.to_integer().abs()).sum();
                Ok(SqliteValue::Integer(sum))
            }

            fn num_args(&self) -> i32 {
                -1
            }

            fn name(&self) -> &str {
                "abs"
            }
        }

        let mut registry = FunctionRegistry::new();
        registry.register_scalar(Abs1);
        registry.register_scalar(AbsVariadic);

        // Exact arity takes precedence over the variadic registration.
        let f = registry.find_scalar("abs", 1).expect("abs(1) found");
        assert_eq!(f.num_args(), 1, "exact 1-arg match");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(-5)]).unwrap(),
            SqliteValue::Integer(5)
        );

        // No 2-argument registration exists, so the variadic one is used.
        let f = registry.find_scalar("abs", 2).expect("abs variadic found");
        assert_eq!(f.num_args(), -1, "variadic fallback for 2 args");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(-5), SqliteValue::Integer(-3)])
                .unwrap(),
            SqliteValue::Integer(8)
        );

        // Unknown names resolve to nothing.
        assert!(registry.find_scalar("nonexistent", 1).is_none());
    }

    #[test]
    fn test_authorizer_called_at_compile_time() {
        use authorizer::{AuthAction, AuthResult, Authorizer};
        use std::sync::Mutex;

        // Authorizer that records every action it is asked about.
        struct TrackingAuthorizer {
            calls: Mutex<Vec<AuthAction>>,
        }

        impl TrackingAuthorizer {
            fn new() -> Self {
                Self {
                    calls: Mutex::new(Vec::new()),
                }
            }
        }

        impl Authorizer for TrackingAuthorizer {
            fn authorize(
                &self,
                action: AuthAction,
                _arg1: Option<&str>,
                _arg2: Option<&str>,
                _db_name: Option<&str>,
                _trigger: Option<&str>,
            ) -> AuthResult {
                self.calls.lock().unwrap().push(action);
                AuthResult::Ok
            }
        }

        let auth = TrackingAuthorizer::new();

        // Simulated prepare-time callback sequence: one Select followed by one
        // Read per referenced column. The test invokes the authorizer directly
        // rather than through a real statement prepare.
        auth.authorize(AuthAction::Select, None, None, Some("main"), None);
        auth.authorize(
            AuthAction::Read,
            Some("users"),
            Some("name"),
            Some("main"),
            None,
        );
        auth.authorize(
            AuthAction::Read,
            Some("users"),
            Some("email"),
            Some("main"),
            None,
        );
        auth.authorize(
            AuthAction::Read,
            Some("users"),
            Some("id"),
            Some("main"),
            None,
        );

        let calls = auth.calls.lock().unwrap();
        assert_eq!(calls.len(), 4, "authorizer called 4 times during prepare");
        assert_eq!(calls[0], AuthAction::Select);
        assert_eq!(calls[1], AuthAction::Read);
        assert_eq!(calls[2], AuthAction::Read);
        assert_eq!(calls[3], AuthAction::Read);
        // Release the guard before re-locking below.
        drop(calls);

        // No further authorize calls are made after the simulated prepare, so
        // the recorded count must be unchanged.
        let calls_after = auth.calls.lock().unwrap();
        assert_eq!(
            calls_after.len(),
            4,
            "authorizer must NOT be called during step/execution"
        );
        drop(calls_after);
    }
}