1#![allow(clippy::unnecessary_literal_bound)]
13
14use std::any::Any;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU64, Ordering};
18
19use tracing::debug;
20
/// Running count of calls recorded through [`record_func_call`] and
/// [`record_func_call_count_only`].
static FSQLITE_FUNC_CALLS_TOTAL: AtomicU64 = AtomicU64::new(0);
/// Running sum, in microseconds, of the durations passed to [`record_func_call`].
static FSQLITE_FUNC_EVAL_DURATION_US_TOTAL: AtomicU64 = AtomicU64::new(0);

/// Point-in-time copy of the function-call counters.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FuncMetricsSnapshot {
    /// Number of recorded function calls.
    pub calls_total: u64,
    /// Sum of recorded evaluation durations, in microseconds.
    pub eval_duration_us_total: u64,
}

/// Reads both function-call counters into a [`FuncMetricsSnapshot`].
///
/// The counters are loaded one after the other with `Relaxed` ordering, so a
/// call recorded concurrently may show up in one field but not the other.
#[must_use]
pub fn func_metrics_snapshot() -> FuncMetricsSnapshot {
    let calls_total = FSQLITE_FUNC_CALLS_TOTAL.load(Ordering::Relaxed);
    let eval_duration_us_total = FSQLITE_FUNC_EVAL_DURATION_US_TOTAL.load(Ordering::Relaxed);
    FuncMetricsSnapshot {
        calls_total,
        eval_duration_us_total,
    }
}

/// Sets both function-call counters back to zero (call count first, then
/// duration total).
pub fn reset_func_metrics() {
    for counter in [&FSQLITE_FUNC_CALLS_TOTAL, &FSQLITE_FUNC_EVAL_DURATION_US_TOTAL] {
        counter.store(0, Ordering::Relaxed);
    }
}

/// Records one function call that took `duration_us` microseconds.
pub fn record_func_call(duration_us: u64) {
    // Bump the call count exactly as the count-only path does, then add the
    // elapsed time to the duration total.
    record_func_call_count_only();
    FSQLITE_FUNC_EVAL_DURATION_US_TOTAL.fetch_add(duration_us, Ordering::Relaxed);
}

/// Records one function call without contributing to the duration total.
pub fn record_func_call_count_only() {
    FSQLITE_FUNC_CALLS_TOTAL.fetch_add(1, Ordering::Relaxed);
}
62
/// Counts calls to [`record_udf_registered`] since start-up (it starts at
/// zero) or since the last [`reset_udf_metrics`].
static FSQLITE_UDF_REGISTERED: AtomicU64 = AtomicU64::new(0);

/// Increments the UDF-registration counter by one.
pub fn record_udf_registered() {
    FSQLITE_UDF_REGISTERED.fetch_add(1_u64, Ordering::Relaxed);
}

/// Returns the current value of the UDF-registration counter.
#[must_use]
pub fn udf_registered_count() -> u64 {
    FSQLITE_UDF_REGISTERED.load(Ordering::Relaxed)
}

/// Resets the UDF-registration counter to zero.
pub fn reset_udf_metrics() {
    FSQLITE_UDF_REGISTERED.store(u64::MIN, Ordering::Relaxed);
}
83
84pub mod agg_builtins;
85pub mod aggregate;
86pub mod authorizer;
87pub mod builtins;
88pub mod collation;
89pub mod datetime;
90pub mod math;
91pub mod scalar;
92pub mod vtab;
93pub mod window;
94pub mod window_builtins;
95
96pub use agg_builtins::register_aggregate_builtins;
97pub use aggregate::{AggregateAdapter, AggregateFunction};
98pub use authorizer::{AuthAction, AuthResult, Authorizer, AuthorizerAction, AuthorizerDecision};
99pub use builtins::{
100 ChangeTrackingState, get_last_changes, get_last_insert_rowid, get_total_changes,
101 register_builtins, reset_total_changes, set_change_tracking_state, set_last_changes,
102 set_last_insert_rowid,
103};
104pub use collation::{
105 BinaryCollation, CollationAnnotation, CollationFunction, CollationRegistry, CollationSource,
106 NoCaseCollation, RtrimCollation, resolve_collation,
107};
108pub use datetime::register_datetime_builtins;
109pub use math::register_math_builtins;
110pub use scalar::ScalarFunction;
111pub use vtab::{
112 ColumnContext, ConstraintOp, IndexConstraint, IndexConstraintUsage, IndexInfo, IndexOrderBy,
113 VirtualTable, VirtualTableCursor,
114};
115pub use window::{WindowAdapter, WindowFunction};
116pub use window_builtins::register_window_builtins;
117
/// Type-erased aggregate function whose per-group state is a
/// `Box<dyn Any + Send>`.
///
/// This is the form stored in [`FunctionRegistry`]: `register_aggregate`
/// wraps each concrete function in an [`AggregateAdapter`] before storing it.
pub type ErasedAggregateFunction = dyn AggregateFunction<State = Box<dyn Any + Send>>;

/// Type-erased window function whose per-partition state is a
/// `Box<dyn Any + Send>`.
///
/// This is the form stored in [`FunctionRegistry`]: `register_window` wraps
/// each concrete function in a [`WindowAdapter`] before storing it.
pub type ErasedWindowFunction = dyn WindowFunction<State = Box<dyn Any + Send>>;
123
/// Registry lookup key: a function name paired with its declared arity.
///
/// An arity of `-1` denotes a variadic registration, which the registry's
/// `find_*` methods fall back to when no exact-arity entry exists.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct FunctionKey {
    /// Function name. Keys built with [`FunctionKey::new`] hold the trimmed,
    /// ASCII-uppercased form.
    pub name: String,
    /// Declared number of arguments, or `-1` for a variadic function.
    pub num_args: i32,
}
135
136impl FunctionKey {
137 #[must_use]
139 pub fn new(name: &str, num_args: i32) -> Self {
140 Self {
141 name: canonical_name(name),
142 num_args,
143 }
144 }
145}
146
/// Registry of SQL functions, kept in separate maps per kind (scalar,
/// aggregate, window).
///
/// Every map is keyed by [`FunctionKey`] (canonical name plus arity). A key
/// with `num_args == -1` is a variadic registration that the `find_*` methods
/// use as a fallback when no exact-arity entry exists. Aggregate and window
/// functions are stored in their type-erased forms.
#[derive(Default)]
pub struct FunctionRegistry {
    /// Scalar functions, by canonical name and arity.
    scalars: HashMap<FunctionKey, Arc<dyn ScalarFunction>>,
    /// Aggregate functions, each wrapped in an `AggregateAdapter` at registration.
    aggregates: HashMap<FunctionKey, Arc<ErasedAggregateFunction>>,
    /// Window functions, each wrapped in a `WindowAdapter` at registration.
    windows: HashMap<FunctionKey, Arc<ErasedWindowFunction>>,
}
160
161impl FunctionRegistry {
162 #[must_use]
164 pub fn new() -> Self {
165 Self::default()
166 }
167
168 #[must_use]
173 pub fn clone_from_arc(arc: &Arc<Self>) -> Self {
174 Self {
175 scalars: arc.scalars.clone(),
176 aggregates: arc.aggregates.clone(),
177 windows: arc.windows.clone(),
178 }
179 }
180
181 pub fn register_scalar<F>(&mut self, function: F) -> Option<Arc<dyn ScalarFunction>>
186 where
187 F: ScalarFunction + 'static,
188 {
189 let key = FunctionKey::new(function.name(), function.num_args());
190 self.scalars.insert(key, Arc::new(function))
191 }
192
193 pub fn register_aggregate<F>(&mut self, function: F) -> Option<Arc<ErasedAggregateFunction>>
197 where
198 F: AggregateFunction + 'static,
199 F::State: 'static,
200 {
201 let key = FunctionKey::new(function.name(), function.num_args());
202 self.aggregates
203 .insert(key, Arc::new(AggregateAdapter::new(function)))
204 }
205
206 pub fn register_window<F>(&mut self, function: F) -> Option<Arc<ErasedWindowFunction>>
210 where
211 F: WindowFunction + 'static,
212 F::State: 'static,
213 {
214 let key = FunctionKey::new(function.name(), function.num_args());
215 self.windows
216 .insert(key, Arc::new(WindowAdapter::new(function)))
217 }
218
219 #[must_use]
224 pub fn find_scalar(&self, name: &str, num_args: i32) -> Option<Arc<dyn ScalarFunction>> {
225 let canon = canonical_name(name);
226 self.find_scalar_precanonical(&canon, num_args)
227 }
228
229 #[must_use]
234 pub fn find_scalar_precanonical(
235 &self,
236 canonical: &str,
237 num_args: i32,
238 ) -> Option<Arc<dyn ScalarFunction>> {
239 let exact = FunctionKey {
240 name: canonical.to_owned(),
241 num_args,
242 };
243 if let Some(f) = self.scalars.get(&exact) {
244 debug!(name = %canonical, arity = num_args, kind = "scalar", hit = "exact", "registry lookup");
245 return Some(Arc::clone(f));
246 }
247 let variadic = FunctionKey {
249 name: canonical.to_owned(),
250 num_args: -1,
251 };
252 let result = self.scalars.get(&variadic).map(Arc::clone);
253 debug!(
254 name = %canonical,
255 arity = num_args,
256 kind = "scalar",
257 hit = if result.is_some() { "variadic" } else { "miss" },
258 "registry lookup"
259 );
260 result
261 }
262
263 #[must_use]
267 pub fn find_aggregate(
268 &self,
269 name: &str,
270 num_args: i32,
271 ) -> Option<Arc<ErasedAggregateFunction>> {
272 let canon = canonical_name(name);
273 self.find_aggregate_precanonical(&canon, num_args)
274 }
275
276 #[must_use]
278 pub fn find_aggregate_precanonical(
279 &self,
280 canonical: &str,
281 num_args: i32,
282 ) -> Option<Arc<ErasedAggregateFunction>> {
283 let exact = FunctionKey {
284 name: canonical.to_owned(),
285 num_args,
286 };
287 if let Some(f) = self.aggregates.get(&exact) {
288 debug!(name = %canonical, arity = num_args, kind = "aggregate", hit = "exact", "registry lookup");
289 return Some(Arc::clone(f));
290 }
291 let variadic = FunctionKey {
292 name: canonical.to_owned(),
293 num_args: -1,
294 };
295 let result = self.aggregates.get(&variadic).map(Arc::clone);
296 debug!(
297 name = %canonical,
298 arity = num_args,
299 kind = "aggregate",
300 hit = if result.is_some() { "variadic" } else { "miss" },
301 "registry lookup"
302 );
303 result
304 }
305
306 #[must_use]
310 pub fn find_window(&self, name: &str, num_args: i32) -> Option<Arc<ErasedWindowFunction>> {
311 let canon = canonical_name(name);
312 let exact = FunctionKey {
313 name: canon.clone(),
314 num_args,
315 };
316 if let Some(f) = self.windows.get(&exact) {
317 debug!(name = %canon, arity = num_args, kind = "window", hit = "exact", "registry lookup");
318 return Some(Arc::clone(f));
319 }
320 let variadic = FunctionKey {
321 name: canon.clone(),
322 num_args: -1,
323 };
324 let result = self.windows.get(&variadic).map(Arc::clone);
325 debug!(
326 name = %canon,
327 arity = num_args,
328 kind = "window",
329 hit = if result.is_some() { "variadic" } else { "miss" },
330 "registry lookup"
331 );
332 result
333 }
334
335 #[must_use]
338 pub fn contains_scalar(&self, name: &str) -> bool {
339 let canon = canonical_name(name);
340 self.scalars.keys().any(|k| k.name == canon)
341 }
342
343 #[must_use]
346 pub fn contains_aggregate(&self, name: &str) -> bool {
347 let canon = canonical_name(name);
348 self.aggregates.keys().any(|k| k.name == canon)
349 }
350
351 #[must_use]
354 pub fn contains_window(&self, name: &str) -> bool {
355 let canon = canonical_name(name);
356 self.windows.keys().any(|k| k.name == canon)
357 }
358
359 #[must_use]
363 pub fn aggregate_names_lowercase(&self) -> Vec<String> {
364 let mut names: Vec<String> = self
365 .aggregates
366 .keys()
367 .map(|k| k.name.to_ascii_lowercase())
368 .collect();
369 names.sort();
370 names.dedup();
371 names
372 }
373}
374
/// Normalizes a function name for use as a registry key.
///
/// Surrounding whitespace is trimmed and ASCII letters are uppercased;
/// non-ASCII characters are left unchanged.
fn canonical_name(name: &str) -> String {
    let mut canonical = name.trim().to_owned();
    canonical.make_ascii_uppercase();
    canonical
}
378
#[cfg(test)]
mod tests {
    use fsqlite_types::SqliteValue;

    use super::*;

    /// Scalar fixture: `double(x)` returns `2 * x` as an integer.
    struct Double;

    impl ScalarFunction for Double {
        fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(args[0].to_integer() * 2))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "double"
        }
    }

    /// Variadic scalar fixture named `my_func`: concatenates the text of all
    /// arguments.
    struct VariadicConcat;

    impl ScalarFunction for VariadicConcat {
        fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
            let mut out = String::new();
            for a in args {
                out.push_str(&a.to_text());
            }
            Ok(SqliteValue::Text(out.into()))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "my_func"
        }
    }

    /// Two-argument scalar fixture also named `my_func`: integer addition.
    struct TwoArgFunc;

    impl ScalarFunction for TwoArgFunc {
        fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(
                args[0].to_integer() + args[1].to_integer(),
            ))
        }

        fn num_args(&self) -> i32 {
            2
        }

        fn name(&self) -> &str {
            "my_func"
        }
    }

    /// One-argument aggregate fixture: product of all stepped values.
    struct Product;

    impl AggregateFunction for Product {
        type State = i64;

        fn initial_state(&self) -> Self::State {
            1
        }

        fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> fsqlite_error::Result<()> {
            *state *= args[0].to_integer();
            Ok(())
        }

        fn finalize(&self, state: Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "product"
        }
    }

    /// Variadic aggregate fixture registered as `PRODUCT`. It multiplies
    /// every argument of every step; it shares its canonical name with
    /// `Product` but not its arity.
    struct VariadicProduct;

    impl AggregateFunction for VariadicProduct {
        type State = i64;

        fn initial_state(&self) -> Self::State {
            1
        }

        fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> fsqlite_error::Result<()> {
            for arg in args {
                *state *= arg.to_integer();
            }
            Ok(())
        }

        fn finalize(&self, state: Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "PRODUCT"
        }
    }

    /// One-argument window fixture: running sum that supports `inverse`.
    struct MovingSum;

    impl WindowFunction for MovingSum {
        type State = i64;

        fn initial_state(&self) -> Self::State {
            0
        }

        fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> fsqlite_error::Result<()> {
            *state += args[0].to_integer();
            Ok(())
        }

        fn inverse(
            &self,
            state: &mut Self::State,
            args: &[SqliteValue],
        ) -> fsqlite_error::Result<()> {
            *state -= args[0].to_integer();
            Ok(())
        }

        fn value(&self, state: &Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(*state))
        }

        fn finalize(&self, state: Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "moving_sum"
        }
    }

    /// Variadic window fixture: counts rows in the frame, ignoring arguments.
    struct WindowCount;

    impl WindowFunction for WindowCount {
        type State = i64;

        fn initial_state(&self) -> Self::State {
            0
        }

        fn step(&self, state: &mut Self::State, _args: &[SqliteValue]) -> fsqlite_error::Result<()> {
            *state += 1;
            Ok(())
        }

        fn inverse(
            &self,
            state: &mut Self::State,
            _args: &[SqliteValue],
        ) -> fsqlite_error::Result<()> {
            *state -= 1;
            Ok(())
        }

        fn value(&self, state: &Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(*state))
        }

        fn finalize(&self, state: Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "win_count"
        }
    }

    #[test]
    fn test_registry_register_scalar() {
        let mut registry = FunctionRegistry::new();
        let previous = registry.register_scalar(Double);
        assert!(previous.is_none());
        assert!(registry.contains_scalar("double"));
        assert!(registry.contains_scalar("DOUBLE"));
        let f = registry
            .find_scalar(" Double ", 1)
            .expect("double registered");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(21)])
                .expect("invoke succeeds"),
            SqliteValue::Integer(42)
        );
    }

    #[test]
    fn test_registry_case_insensitive_lookup() {
        let mut registry = FunctionRegistry::new();
        registry.register_scalar(Double);

        assert!(registry.find_scalar("DOUBLE", 1).is_some());
        assert!(registry.find_scalar("Double", 1).is_some());
        assert!(registry.find_scalar(" double ", 1).is_some());
    }

    #[test]
    fn test_registry_overwrite() {
        let mut registry = FunctionRegistry::new();

        // First registration: nothing to replace.
        let prev = registry.register_scalar(Double);
        assert!(prev.is_none());

        // Same (name, arity): the old entry is returned and replaced.
        let prev = registry.register_scalar(Double);
        assert!(prev.is_some());

        let f = registry.find_scalar("double", 1).unwrap();
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(5)]).unwrap(),
            SqliteValue::Integer(10)
        );
    }

    #[test]
    fn test_registry_variadic_fallback() {
        let mut registry = FunctionRegistry::new();

        registry.register_scalar(VariadicConcat);

        // No 3-arg registration exists, so the variadic one is used.
        let f = registry
            .find_scalar("my_func", 3)
            .expect("variadic fallback");
        assert_eq!(
            f.invoke(&[
                SqliteValue::Text("a".into()),
                SqliteValue::Text("b".into()),
                SqliteValue::Text("c".into()),
            ])
            .unwrap(),
            SqliteValue::Text("abc".into())
        );
    }

    #[test]
    fn test_registry_exact_match_over_variadic() {
        let mut registry = FunctionRegistry::new();

        registry.register_scalar(VariadicConcat);
        registry.register_scalar(TwoArgFunc);

        // An exact arity match beats the variadic registration.
        let f = registry
            .find_scalar("my_func", 2)
            .expect("exact match found");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(10), SqliteValue::Integer(32)])
                .unwrap(),
            SqliteValue::Integer(42)
        );

        // Any other arity still falls back to the variadic one.
        let f = registry
            .find_scalar("my_func", 5)
            .expect("variadic fallback");
        assert_eq!(f.num_args(), -1);
    }

    #[test]
    fn test_registry_not_found_returns_none() {
        let registry = FunctionRegistry::new();
        assert!(registry.find_scalar("nonexistent", 1).is_none());
        assert!(registry.find_aggregate("nonexistent", 1).is_none());
        assert!(registry.find_window("nonexistent", 1).is_none());
    }

    #[test]
    fn test_registry_arity_mismatch_without_variadic_misses() {
        let mut registry = FunctionRegistry::new();
        registry.register_scalar(Double);
        registry.register_aggregate(Product);
        registry.register_window(MovingSum);

        // Only exact-arity registrations exist, so other arities must miss.
        assert!(registry.find_scalar("double", 2).is_none());
        assert!(registry.find_aggregate("product", 0).is_none());
        assert!(registry.find_window("moving_sum", 3).is_none());
    }

    #[test]
    fn test_registry_kinds_are_separate_namespaces() {
        let mut registry = FunctionRegistry::new();
        registry.register_aggregate(Product);

        assert!(registry.contains_aggregate("PRODUCT"));
        assert!(!registry.contains_scalar("product"));
        assert!(!registry.contains_window("product"));
        assert!(registry.find_scalar("product", 1).is_none());
        assert!(registry.find_window("product", 1).is_none());
    }

    #[test]
    fn test_registry_register_and_resolve_aggregate() {
        let mut registry = FunctionRegistry::new();
        let previous = registry.register_aggregate(Product);
        assert!(previous.is_none());
        assert!(registry.contains_aggregate("product"));
        let f = registry
            .find_aggregate("PRODUCT", 1)
            .expect("product aggregate registered");

        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(2)])
            .expect("step 1");
        f.step(&mut state, &[SqliteValue::Integer(3)])
            .expect("step 2");
        f.step(&mut state, &[SqliteValue::Integer(7)])
            .expect("step 3");

        assert_eq!(
            f.finalize(state).expect("finalize succeeds"),
            SqliteValue::Integer(42)
        );
    }

    #[test]
    fn test_registry_aggregate_type_erased() {
        let mut registry = FunctionRegistry::new();
        registry.register_aggregate(Product);

        // The registry hands back the erased form; state is boxed internally.
        let f = registry
            .find_aggregate("product", 1)
            .expect("product found");
        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(6)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(7)]).unwrap();
        assert_eq!(f.finalize(state).unwrap(), SqliteValue::Integer(42));
        assert_eq!(f.name(), "product");
    }

    #[test]
    fn test_registry_aggregate_variadic_fallback() {
        let mut registry = FunctionRegistry::new();
        registry.register_aggregate(VariadicProduct);
        registry.register_aggregate(Product);

        let exact = registry
            .find_aggregate("product", 1)
            .expect("exact aggregate found");
        assert_eq!(exact.num_args(), 1);
        assert_eq!(exact.name(), "product");

        let variadic = registry
            .find_aggregate("Product", 3)
            .expect("variadic aggregate fallback");
        assert_eq!(variadic.num_args(), -1);
        let mut state = variadic.initial_state();
        variadic
            .step(
                &mut state,
                &[
                    SqliteValue::Integer(2),
                    SqliteValue::Integer(3),
                    SqliteValue::Integer(7),
                ],
            )
            .unwrap();
        assert_eq!(variadic.finalize(state).unwrap(), SqliteValue::Integer(42));
    }

    #[test]
    fn test_aggregate_names_lowercase_dedups_overloads() {
        let mut registry = FunctionRegistry::new();
        assert!(registry.aggregate_names_lowercase().is_empty());

        // Two arities under one canonical name must be reported once, lowercased.
        registry.register_aggregate(VariadicProduct);
        registry.register_aggregate(Product);
        assert_eq!(
            registry.aggregate_names_lowercase(),
            vec!["product".to_owned()]
        );
    }

    #[test]
    fn test_clone_from_arc_is_independent() {
        let mut base = FunctionRegistry::new();
        base.register_scalar(Double);
        let shared = std::sync::Arc::new(base);

        let mut copy = FunctionRegistry::clone_from_arc(&shared);
        copy.register_scalar(VariadicConcat);

        // The copy sees the original entries plus its own additions...
        assert!(copy.find_scalar("double", 1).is_some());
        assert!(copy.contains_scalar("my_func"));
        // ...while the shared original is unaffected.
        assert!(!shared.contains_scalar("my_func"));
    }

    #[test]
    fn test_registry_register_and_resolve_window() {
        let mut registry = FunctionRegistry::new();
        let previous = registry.register_window(MovingSum);
        assert!(previous.is_none());
        assert!(registry.contains_window("moving_sum"));
        let f = registry
            .find_window("MOVING_SUM", 1)
            .expect("moving_sum window registered");

        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(10)])
            .expect("step 1");
        f.step(&mut state, &[SqliteValue::Integer(20)])
            .expect("step 2");
        f.step(&mut state, &[SqliteValue::Integer(30)])
            .expect("step 3");
        assert_eq!(f.value(&state).expect("value"), SqliteValue::Integer(60));

        // Slide the frame: drop the oldest row, add a new one.
        f.inverse(&mut state, &[SqliteValue::Integer(10)])
            .expect("inverse 1");
        f.step(&mut state, &[SqliteValue::Integer(40)])
            .expect("step 4");
        assert_eq!(f.value(&state).expect("value"), SqliteValue::Integer(90));
    }

    #[test]
    fn test_registry_window_type_erased() {
        let mut registry = FunctionRegistry::new();
        registry.register_window(MovingSum);

        let f = registry
            .find_window("moving_sum", 1)
            .expect("moving_sum found");

        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(100)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(100));

        f.inverse(&mut state, &[SqliteValue::Integer(100)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(0));

        f.step(&mut state, &[SqliteValue::Integer(42)]).unwrap();
        assert_eq!(f.finalize(state).unwrap(), SqliteValue::Integer(42));
    }

    #[test]
    fn test_registry_window_variadic_fallback() {
        let mut registry = FunctionRegistry::new();
        registry.register_window(WindowCount);

        let f = registry
            .find_window("WIN_COUNT", 4)
            .expect("variadic window fallback");
        assert_eq!(f.num_args(), -1);

        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(1), SqliteValue::Integer(2)])
            .unwrap();
        f.step(&mut state, &[SqliteValue::Integer(3), SqliteValue::Integer(4)])
            .unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(2));
        f.inverse(&mut state, &[SqliteValue::Integer(1), SqliteValue::Integer(2)])
            .unwrap();
        assert_eq!(f.finalize(state).unwrap(), SqliteValue::Integer(1));
    }

    #[test]
    fn test_function_key_equality() {
        let k1 = FunctionKey::new("ABS", 1);
        let k2 = FunctionKey::new("abs", 1);
        let k3 = FunctionKey::new("ABS", 2);

        assert_eq!(k1, k2, "case-insensitive equality");
        assert_ne!(k1, k3, "different num_args");
    }

    #[test]
    fn test_e2e_custom_collation_in_order_by() {
        use collation::{BinaryCollation, CollationFunction, NoCaseCollation, RtrimCollation};

        /// Collation that orders byte strings in descending order.
        struct ReverseAlpha;

        impl CollationFunction for ReverseAlpha {
            fn name(&self) -> &str {
                "REVERSE_ALPHA"
            }

            fn compare(&self, left: &[u8], right: &[u8]) -> std::cmp::Ordering {
                right.cmp(left)
            }
        }

        let coll = ReverseAlpha;
        let mut data: Vec<&[u8]> = vec![b"banana", b"apple", b"cherry", b"date"];
        data.sort_by(|a, b| coll.compare(a, b));

        let expected: Vec<&[u8]> = vec![b"date", b"cherry", b"banana", b"apple"];
        assert_eq!(data, expected);
        assert_eq!(coll.name(), "REVERSE_ALPHA");

        // Built-in and custom collations are usable through the same trait object.
        let collations: Vec<Box<dyn CollationFunction>> = vec![
            Box::new(BinaryCollation),
            Box::new(NoCaseCollation),
            Box::new(RtrimCollation),
            Box::new(ReverseAlpha),
        ];
        assert_eq!(collations.len(), 4);

        let mut binary_sorted = data.clone();
        binary_sorted.sort_by(|a, b| collations[0].compare(a, b));
        assert_eq!(binary_sorted[0], b"apple");
    }

    #[test]
    fn test_e2e_authorizer_sandboxing() {
        use authorizer::{AuthAction, AuthResult, Authorizer};

        /// Allows reads, denies writes, and hides the `secret` column.
        struct SelectOnlyAuthorizer;

        impl Authorizer for SelectOnlyAuthorizer {
            fn authorize(
                &self,
                action: AuthAction,
                _arg1: Option<&str>,
                arg2: Option<&str>,
                _db_name: Option<&str>,
                _trigger: Option<&str>,
            ) -> AuthResult {
                match action {
                    AuthAction::Select | AuthAction::Read => {
                        // For Read, arg2 is the column name being read.
                        if action == AuthAction::Read && arg2 == Some("secret") {
                            return AuthResult::Ignore;
                        }
                        AuthResult::Ok
                    }
                    AuthAction::Insert | AuthAction::Update | AuthAction::Delete => {
                        AuthResult::Deny
                    }
                    _ => AuthResult::Ok,
                }
            }
        }

        let auth = SelectOnlyAuthorizer;

        assert_eq!(
            auth.authorize(AuthAction::Select, None, None, Some("main"), None),
            AuthResult::Ok,
            "SELECT must be allowed"
        );

        assert_eq!(
            auth.authorize(AuthAction::Insert, Some("users"), None, Some("main"), None),
            AuthResult::Deny,
            "INSERT must be denied (compile-time auth error)"
        );

        assert_eq!(
            auth.authorize(
                AuthAction::Update,
                Some("users"),
                Some("email"),
                Some("main"),
                None
            ),
            AuthResult::Deny,
        );

        assert_eq!(
            auth.authorize(AuthAction::Delete, Some("users"), None, Some("main"), None),
            AuthResult::Deny,
        );

        assert_eq!(
            auth.authorize(
                AuthAction::Read,
                Some("users"),
                Some("secret"),
                Some("main"),
                None
            ),
            AuthResult::Ignore,
            "Ignore must nullify column"
        );

        assert_eq!(
            auth.authorize(
                AuthAction::Read,
                Some("users"),
                Some("name"),
                Some("main"),
                None
            ),
            AuthResult::Ok,
        );
    }

    #[test]
    fn test_e2e_function_registry_resolution() {
        /// Exact one-argument `abs`.
        struct Abs1;

        impl ScalarFunction for Abs1 {
            fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
                Ok(SqliteValue::Integer(args[0].to_integer().abs()))
            }

            fn num_args(&self) -> i32 {
                1
            }

            fn name(&self) -> &str {
                "abs"
            }
        }

        /// Variadic `abs`: sum of absolute values of all arguments.
        struct AbsVariadic;

        impl ScalarFunction for AbsVariadic {
            fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
                let sum: i64 = args.iter().map(|a| a.to_integer().abs()).sum();
                Ok(SqliteValue::Integer(sum))
            }

            fn num_args(&self) -> i32 {
                -1
            }

            fn name(&self) -> &str {
                "abs"
            }
        }

        let mut registry = FunctionRegistry::new();
        registry.register_scalar(Abs1);
        registry.register_scalar(AbsVariadic);

        let f = registry.find_scalar("abs", 1).expect("abs(1) found");
        assert_eq!(f.num_args(), 1, "exact 1-arg match");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(-5)]).unwrap(),
            SqliteValue::Integer(5)
        );

        let f = registry.find_scalar("abs", 2).expect("abs variadic found");
        assert_eq!(f.num_args(), -1, "variadic fallback for 2 args");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(-5), SqliteValue::Integer(-3)])
                .unwrap(),
            SqliteValue::Integer(8)
        );

        assert!(registry.find_scalar("nonexistent", 1).is_none());
    }

    #[test]
    fn test_authorizer_called_at_compile_time() {
        use authorizer::{AuthAction, AuthResult, Authorizer};
        use std::sync::Mutex;

        /// Authorizer that records every action it is asked about.
        struct TrackingAuthorizer {
            calls: Mutex<Vec<AuthAction>>,
        }

        impl TrackingAuthorizer {
            fn new() -> Self {
                Self {
                    calls: Mutex::new(Vec::new()),
                }
            }
        }

        impl Authorizer for TrackingAuthorizer {
            fn authorize(
                &self,
                action: AuthAction,
                _arg1: Option<&str>,
                _arg2: Option<&str>,
                _db_name: Option<&str>,
                _trigger: Option<&str>,
            ) -> AuthResult {
                self.calls.lock().unwrap().push(action);
                AuthResult::Ok
            }
        }

        let auth = TrackingAuthorizer::new();

        // Simulated prepare phase: the calls a compiler would issue for
        // `SELECT name, email, id FROM users`. No real statement is compiled
        // or executed in this test.
        auth.authorize(AuthAction::Select, None, None, Some("main"), None);
        auth.authorize(
            AuthAction::Read,
            Some("users"),
            Some("name"),
            Some("main"),
            None,
        );
        auth.authorize(
            AuthAction::Read,
            Some("users"),
            Some("email"),
            Some("main"),
            None,
        );
        auth.authorize(
            AuthAction::Read,
            Some("users"),
            Some("id"),
            Some("main"),
            None,
        );

        let calls = auth.calls.lock().unwrap();
        assert_eq!(calls.len(), 4, "authorizer called 4 times during prepare");
        assert_eq!(calls[0], AuthAction::Select);
        assert_eq!(calls[1], AuthAction::Read);
        assert_eq!(calls[2], AuthAction::Read);
        assert_eq!(calls[3], AuthAction::Read);
        drop(calls);

        // Simulated step phase: nothing here may consult the authorizer, so
        // the recorded call count must be unchanged.
        let calls_after = auth.calls.lock().unwrap();
        assert_eq!(
            calls_after.len(),
            4,
            "authorizer must NOT be called during step/execution"
        );
        drop(calls_after);
    }
}