Skip to main content

fsqlite_func/
lib.rs

1//! Built-in SQL function and extension trait surfaces.
2//!
3//! This crate defines open, user-implementable traits for:
4//! - scalar, aggregate, and window functions
5//! - virtual table modules/cursors
6//! - collation callbacks
7//! - authorizer callbacks
8//!
9//! It also provides a small in-memory [`FunctionRegistry`] for registering and
10//! resolving scalar/aggregate/window functions by `(name, num_args)` key with
11//! variadic fallback.
12#![allow(clippy::unnecessary_literal_bound)]
13
14use std::any::Any;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU64, Ordering};
18
19use tracing::debug;
20
21// ── Function evaluation metrics (bd-2wt.1) ─────────────────────────────────
22
/// Running count of scalar function invocations, summed over every statement.
static FSQLITE_FUNC_CALLS_TOTAL: AtomicU64 = AtomicU64::new(0);
/// Running sum of recorded function evaluation time, in microseconds.
static FSQLITE_FUNC_EVAL_DURATION_US_TOTAL: AtomicU64 = AtomicU64::new(0);

/// Point-in-time copy of the function evaluation counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FuncMetricsSnapshot {
    /// Number of scalar function calls recorded so far.
    pub calls_total: u64,
    /// Total recorded evaluation time, in microseconds.
    pub eval_duration_us_total: u64,
}

/// Capture the current values of the function evaluation counters.
///
/// The two counters are loaded one after the other with `Relaxed` ordering.
/// If calls are being recorded concurrently, the pair is therefore not
/// guaranteed to reflect the same instant.
#[must_use]
pub fn func_metrics_snapshot() -> FuncMetricsSnapshot {
    let calls_total = FSQLITE_FUNC_CALLS_TOTAL.load(Ordering::Relaxed);
    let eval_duration_us_total = FSQLITE_FUNC_EVAL_DURATION_US_TOTAL.load(Ordering::Relaxed);
    FuncMetricsSnapshot {
        calls_total,
        eval_duration_us_total,
    }
}

/// Zero both function evaluation counters (intended for tests/diagnostics).
pub fn reset_func_metrics() {
    for counter in [&FSQLITE_FUNC_CALLS_TOTAL, &FSQLITE_FUNC_EVAL_DURATION_US_TOTAL] {
        counter.store(0, Ordering::Relaxed);
    }
}

/// Account for one function call that took `duration_us` microseconds.
///
/// Invoked by the VDBE engine after each evaluation.
pub fn record_func_call(duration_us: u64) {
    let calls = &FSQLITE_FUNC_CALLS_TOTAL;
    let elapsed = &FSQLITE_FUNC_EVAL_DURATION_US_TOTAL;
    calls.fetch_add(1, Ordering::Relaxed);
    elapsed.fetch_add(duration_us, Ordering::Relaxed);
}
57
58// ── UDF registration metrics (bd-2wt.3) ────────────────────────────────
59
/// Running count of user-defined function registrations.
static FSQLITE_UDF_REGISTERED: AtomicU64 = AtomicU64::new(0);

/// Bump the UDF registration counter by one.
pub fn record_udf_registered() {
    let counter = &FSQLITE_UDF_REGISTERED;
    counter.fetch_add(1, Ordering::Relaxed);
}

/// Number of UDF registrations recorded since start-up or the last reset.
#[must_use]
pub fn udf_registered_count() -> u64 {
    let counter = &FSQLITE_UDF_REGISTERED;
    counter.load(Ordering::Relaxed)
}

/// Zero the UDF registration counter (intended for tests/diagnostics).
pub fn reset_udf_metrics() {
    let counter = &FSQLITE_UDF_REGISTERED;
    counter.store(0, Ordering::Relaxed);
}
78
79pub mod agg_builtins;
80pub mod aggregate;
81pub mod authorizer;
82pub mod builtins;
83pub mod collation;
84pub mod datetime;
85pub mod math;
86pub mod scalar;
87pub mod vtab;
88pub mod window;
89pub mod window_builtins;
90
91pub use agg_builtins::register_aggregate_builtins;
92pub use aggregate::{AggregateAdapter, AggregateFunction};
93pub use authorizer::{AuthAction, AuthResult, Authorizer, AuthorizerAction, AuthorizerDecision};
94pub use builtins::{
95    get_last_changes, get_last_insert_rowid, register_builtins, set_last_changes,
96    set_last_insert_rowid,
97};
98pub use collation::{
99    BinaryCollation, CollationAnnotation, CollationFunction, CollationRegistry, CollationSource,
100    NoCaseCollation, RtrimCollation, resolve_collation,
101};
102pub use datetime::register_datetime_builtins;
103pub use math::register_math_builtins;
104pub use scalar::ScalarFunction;
105pub use vtab::{
106    ColumnContext, ConstraintOp, IndexConstraint, IndexConstraintUsage, IndexInfo, IndexOrderBy,
107    VirtualTable, VirtualTableCursor,
108};
109pub use window::{WindowAdapter, WindowFunction};
110pub use window_builtins::register_window_builtins;
111
/// Type-erased aggregate function object used by the registry.
///
/// Each function's concrete `State` type is hidden behind `Box<dyn Any + Send>`.
/// This lets aggregates with different state types share one map.
/// [`FunctionRegistry::register_aggregate`] wraps concrete functions in an
/// [`AggregateAdapter`] to produce this shape.
pub type ErasedAggregateFunction = dyn AggregateFunction<State = Box<dyn Any + Send>>;

/// Type-erased window function object used by the registry.
///
/// Each function's concrete `State` type is hidden behind `Box<dyn Any + Send>`.
/// [`FunctionRegistry::register_window`] wraps concrete functions in a
/// [`WindowAdapter`] to produce this shape.
pub type ErasedWindowFunction = dyn WindowFunction<State = Box<dyn Any + Send>>;
117
118/// Composite lookup key for functions: `(UPPERCASE name, num_args)`.
119///
120/// `-1` for `num_args` means variadic (any number of arguments).
121/// Names are stored as uppercase ASCII for case-insensitive matching.
122#[derive(Debug, Clone, Hash, Eq, PartialEq)]
123pub struct FunctionKey {
124    /// Function name, stored as uppercase ASCII.
125    pub name: String,
126    /// Expected argument count, or `-1` for variadic.
127    pub num_args: i32,
128}
129
130impl FunctionKey {
131    /// Create a new function key with the name canonicalized to uppercase.
132    #[must_use]
133    pub fn new(name: &str, num_args: i32) -> Self {
134        Self {
135            name: canonical_name(name),
136            num_args,
137        }
138    }
139}
140
/// Registry for scalar, aggregate, and window functions, keyed by
/// `(name, num_args)`.
///
/// Each function kind is stored in its own map. A scalar and an aggregate
/// may therefore share a name without colliding.
///
/// Lookup strategy (§9.5):
/// 1. Exact match on `(UPPERCASE_NAME, num_args)`.
/// 2. Fallback to variadic version `(UPPERCASE_NAME, -1)`.
/// 3. `None` if neither found (caller should raise "no such function").
#[derive(Default)]
pub struct FunctionRegistry {
    /// Scalar functions by canonical `(name, num_args)` key.
    scalars: HashMap<FunctionKey, Arc<dyn ScalarFunction>>,
    /// Aggregate functions, stored type-erased via [`AggregateAdapter`].
    aggregates: HashMap<FunctionKey, Arc<ErasedAggregateFunction>>,
    /// Window functions, stored type-erased via [`WindowAdapter`].
    windows: HashMap<FunctionKey, Arc<ErasedWindowFunction>>,
}
154
155impl FunctionRegistry {
156    /// Create an empty registry.
157    #[must_use]
158    pub fn new() -> Self {
159        Self::default()
160    }
161
162    /// Create a mutable clone of a registry from an `Arc` reference.
163    ///
164    /// This is used by the UDF registration API to produce a new registry
165    /// containing the existing functions plus the newly registered UDF.
166    #[must_use]
167    pub fn clone_from_arc(arc: &Arc<Self>) -> Self {
168        Self {
169            scalars: arc.scalars.clone(),
170            aggregates: arc.aggregates.clone(),
171            windows: arc.windows.clone(),
172        }
173    }
174
175    /// Register a scalar function, keyed by `(name, num_args)`.
176    ///
177    /// Overwrites any existing function with the same key. Returns the
178    /// previous function if one existed.
179    pub fn register_scalar<F>(&mut self, function: F) -> Option<Arc<dyn ScalarFunction>>
180    where
181        F: ScalarFunction + 'static,
182    {
183        let key = FunctionKey::new(function.name(), function.num_args());
184        self.scalars.insert(key, Arc::new(function))
185    }
186
187    /// Register an aggregate function using the type-erased adapter.
188    ///
189    /// Overwrites any existing function with the same `(name, num_args)` key.
190    pub fn register_aggregate<F>(&mut self, function: F) -> Option<Arc<ErasedAggregateFunction>>
191    where
192        F: AggregateFunction + 'static,
193        F::State: 'static,
194    {
195        let key = FunctionKey::new(function.name(), function.num_args());
196        self.aggregates
197            .insert(key, Arc::new(AggregateAdapter::new(function)))
198    }
199
200    /// Register a window function using the type-erased adapter.
201    ///
202    /// Overwrites any existing function with the same `(name, num_args)` key.
203    pub fn register_window<F>(&mut self, function: F) -> Option<Arc<ErasedWindowFunction>>
204    where
205        F: WindowFunction + 'static,
206        F::State: 'static,
207    {
208        let key = FunctionKey::new(function.name(), function.num_args());
209        self.windows
210            .insert(key, Arc::new(WindowAdapter::new(function)))
211    }
212
213    /// Look up a scalar function by `(name, num_args)`.
214    ///
215    /// Tries exact match first, then falls back to the variadic version
216    /// `(name, -1)` if no exact match exists.
217    #[must_use]
218    pub fn find_scalar(&self, name: &str, num_args: i32) -> Option<Arc<dyn ScalarFunction>> {
219        let canon = canonical_name(name);
220        let exact = FunctionKey {
221            name: canon.clone(),
222            num_args,
223        };
224        if let Some(f) = self.scalars.get(&exact) {
225            debug!(name = %canon, arity = num_args, kind = "scalar", hit = "exact", "registry lookup");
226            return Some(Arc::clone(f));
227        }
228        // Variadic fallback
229        let variadic = FunctionKey {
230            name: canon.clone(),
231            num_args: -1,
232        };
233        let result = self.scalars.get(&variadic).map(Arc::clone);
234        debug!(
235            name = %canon,
236            arity = num_args,
237            kind = "scalar",
238            hit = if result.is_some() { "variadic" } else { "miss" },
239            "registry lookup"
240        );
241        result
242    }
243
244    /// Look up an aggregate function by `(name, num_args)`.
245    ///
246    /// Tries exact match first, then falls back to variadic `(name, -1)`.
247    #[must_use]
248    pub fn find_aggregate(
249        &self,
250        name: &str,
251        num_args: i32,
252    ) -> Option<Arc<ErasedAggregateFunction>> {
253        let canon = canonical_name(name);
254        let exact = FunctionKey {
255            name: canon.clone(),
256            num_args,
257        };
258        if let Some(f) = self.aggregates.get(&exact) {
259            debug!(name = %canon, arity = num_args, kind = "aggregate", hit = "exact", "registry lookup");
260            return Some(Arc::clone(f));
261        }
262        let variadic = FunctionKey {
263            name: canon.clone(),
264            num_args: -1,
265        };
266        let result = self.aggregates.get(&variadic).map(Arc::clone);
267        debug!(
268            name = %canon,
269            arity = num_args,
270            kind = "aggregate",
271            hit = if result.is_some() { "variadic" } else { "miss" },
272            "registry lookup"
273        );
274        result
275    }
276
277    /// Look up a window function by `(name, num_args)`.
278    ///
279    /// Tries exact match first, then falls back to variadic `(name, -1)`.
280    #[must_use]
281    pub fn find_window(&self, name: &str, num_args: i32) -> Option<Arc<ErasedWindowFunction>> {
282        let canon = canonical_name(name);
283        let exact = FunctionKey {
284            name: canon.clone(),
285            num_args,
286        };
287        if let Some(f) = self.windows.get(&exact) {
288            debug!(name = %canon, arity = num_args, kind = "window", hit = "exact", "registry lookup");
289            return Some(Arc::clone(f));
290        }
291        let variadic = FunctionKey {
292            name: canon.clone(),
293            num_args: -1,
294        };
295        let result = self.windows.get(&variadic).map(Arc::clone);
296        debug!(
297            name = %canon,
298            arity = num_args,
299            kind = "window",
300            hit = if result.is_some() { "variadic" } else { "miss" },
301            "registry lookup"
302        );
303        result
304    }
305
306    /// Whether the registry contains any scalar function with this name
307    /// (any arg count).
308    #[must_use]
309    pub fn contains_scalar(&self, name: &str) -> bool {
310        let canon = canonical_name(name);
311        self.scalars.keys().any(|k| k.name == canon)
312    }
313
314    /// Whether the registry contains any aggregate function with this name
315    /// (any arg count).
316    #[must_use]
317    pub fn contains_aggregate(&self, name: &str) -> bool {
318        let canon = canonical_name(name);
319        self.aggregates.keys().any(|k| k.name == canon)
320    }
321
322    /// Whether the registry contains any window function with this name
323    /// (any arg count).
324    #[must_use]
325    pub fn contains_window(&self, name: &str) -> bool {
326        let canon = canonical_name(name);
327        self.windows.keys().any(|k| k.name == canon)
328    }
329
330    /// Return deduplicated lowercase names of all registered aggregate functions.
331    ///
332    /// Used by the codegen thread-local to recognize custom aggregate UDFs.
333    #[must_use]
334    pub fn aggregate_names_lowercase(&self) -> Vec<String> {
335        let mut names: Vec<String> = self
336            .aggregates
337            .keys()
338            .map(|k| k.name.to_ascii_lowercase())
339            .collect();
340        names.sort();
341        names.dedup();
342        names
343    }
344}
345
/// Normalize a function name for use in a [`FunctionKey`].
///
/// Surrounding whitespace is removed and ASCII letters are upper-cased.
/// Non-ASCII characters are left unchanged.
fn canonical_name(name: &str) -> String {
    let mut canon = name.trim().to_owned();
    canon.make_ascii_uppercase();
    canon
}
349
#[cfg(test)]
mod tests {
    use fsqlite_types::SqliteValue;

    use super::*;

    // -- Mock: double(x) -> x * 2, fixed 1-arg --

    struct Double;

    impl ScalarFunction for Double {
        fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(args[0].to_integer() * 2))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "double"
        }
    }

    // -- Mock: variadic concat --

    struct VariadicConcat;

    impl ScalarFunction for VariadicConcat {
        fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
            let mut out = String::new();
            for a in args {
                out.push_str(&a.to_text());
            }
            Ok(SqliteValue::Text(out))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "my_func"
        }
    }

    // -- Mock: fixed 2-arg version of same name --

    struct TwoArgFunc;

    impl ScalarFunction for TwoArgFunc {
        fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(
                args[0].to_integer() + args[1].to_integer(),
            ))
        }

        fn num_args(&self) -> i32 {
            2
        }

        fn name(&self) -> &str {
            "my_func"
        }
    }

    // -- Mock: product(x) aggregate, fixed 1-arg --

    struct Product;

    impl AggregateFunction for Product {
        type State = i64;

        fn initial_state(&self) -> Self::State {
            1
        }

        fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> fsqlite_error::Result<()> {
            *state *= args[0].to_integer();
            Ok(())
        }

        fn finalize(&self, state: Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "product"
        }
    }

    // -- Mock: moving_sum(x) window function with inverse support --

    struct MovingSum;

    impl WindowFunction for MovingSum {
        type State = i64;

        fn initial_state(&self) -> Self::State {
            0
        }

        fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> fsqlite_error::Result<()> {
            *state += args[0].to_integer();
            Ok(())
        }

        fn inverse(
            &self,
            state: &mut Self::State,
            args: &[SqliteValue],
        ) -> fsqlite_error::Result<()> {
            *state -= args[0].to_integer();
            Ok(())
        }

        fn value(&self, state: &Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(*state))
        }

        fn finalize(&self, state: Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "moving_sum"
        }
    }

    #[test]
    fn test_registry_register_scalar() {
        let mut registry = FunctionRegistry::new();
        let previous = registry.register_scalar(Double);
        assert!(previous.is_none());
        assert!(registry.contains_scalar("double"));
        assert!(registry.contains_scalar("DOUBLE"));
        let f = registry
            .find_scalar(" Double ", 1)
            .expect("double registered");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(21)])
                .expect("invoke succeeds"),
            SqliteValue::Integer(42)
        );
    }

    #[test]
    fn test_registry_case_insensitive_lookup() {
        let mut registry = FunctionRegistry::new();
        registry.register_scalar(Double);

        // Register as "double", look up as "DOUBLE", "Double", " double "
        assert!(registry.find_scalar("DOUBLE", 1).is_some());
        assert!(registry.find_scalar("Double", 1).is_some());
        assert!(registry.find_scalar(" double ", 1).is_some());
    }

    #[test]
    fn test_registry_overwrite() {
        let mut registry = FunctionRegistry::new();

        // Register first version
        let prev = registry.register_scalar(Double);
        assert!(prev.is_none());

        // Register second version with same (name, num_args) — overwrites
        let prev = registry.register_scalar(Double);
        assert!(prev.is_some());

        // Still works
        let f = registry.find_scalar("double", 1).unwrap();
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(5)]).unwrap(),
            SqliteValue::Integer(10)
        );
    }

    #[test]
    fn test_registry_variadic_fallback() {
        let mut registry = FunctionRegistry::new();

        // Register only the variadic version (num_args = -1)
        registry.register_scalar(VariadicConcat);

        // Look up with specific arg count — no exact match, falls back to variadic
        let f = registry
            .find_scalar("my_func", 3)
            .expect("variadic fallback");
        assert_eq!(
            f.invoke(&[
                SqliteValue::Text("a".to_owned()),
                SqliteValue::Text("b".to_owned()),
                SqliteValue::Text("c".to_owned()),
            ])
            .unwrap(),
            SqliteValue::Text("abc".to_owned())
        );
    }

    #[test]
    fn test_registry_exact_match_over_variadic() {
        let mut registry = FunctionRegistry::new();

        // Register both variadic (num_args=-1) and exact 2-arg version
        registry.register_scalar(VariadicConcat);
        registry.register_scalar(TwoArgFunc);

        // Look up with num_args=2 — exact match wins over variadic
        let f = registry
            .find_scalar("my_func", 2)
            .expect("exact match found");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(10), SqliteValue::Integer(32)])
                .unwrap(),
            SqliteValue::Integer(42)
        );

        // Look up with num_args=5 — no exact match, falls back to variadic
        let f = registry
            .find_scalar("my_func", 5)
            .expect("variadic fallback");
        assert_eq!(f.num_args(), -1);
    }

    #[test]
    fn test_registry_arity_mismatch_without_variadic_is_miss() {
        let mut registry = FunctionRegistry::new();
        registry.register_scalar(Double);

        // Only double(x) exists: no exact entry for these arities and no
        // variadic fallback, so both lookups must miss.
        assert!(registry.find_scalar("double", 0).is_none());
        assert!(registry.find_scalar("double", 2).is_none());
        assert!(!registry.contains_scalar("triple"));
    }

    #[test]
    fn test_registry_not_found_returns_none() {
        let registry = FunctionRegistry::new();
        assert!(registry.find_scalar("nonexistent", 1).is_none());
        assert!(registry.find_aggregate("nonexistent", 1).is_none());
        assert!(registry.find_window("nonexistent", 1).is_none());
    }

    #[test]
    fn test_registry_kinds_do_not_cross_resolve() {
        let mut registry = FunctionRegistry::new();
        registry.register_aggregate(Product);
        registry.register_window(MovingSum);

        // An aggregate is only visible through the aggregate accessors.
        assert!(registry.find_scalar("product", 1).is_none());
        assert!(registry.find_window("product", 1).is_none());
        assert!(!registry.contains_scalar("product"));
        assert!(!registry.contains_window("product"));

        // A window function is only visible through the window accessors.
        assert!(registry.find_scalar("moving_sum", 1).is_none());
        assert!(registry.find_aggregate("moving_sum", 1).is_none());
        assert!(!registry.contains_aggregate("moving_sum"));
    }

    #[test]
    fn test_registry_register_and_resolve_aggregate() {
        let mut registry = FunctionRegistry::new();
        let previous = registry.register_aggregate(Product);
        assert!(previous.is_none());
        assert!(registry.contains_aggregate("product"));
        let f = registry
            .find_aggregate("PRODUCT", 1)
            .expect("product aggregate registered");

        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(2)])
            .expect("step 1");
        f.step(&mut state, &[SqliteValue::Integer(3)])
            .expect("step 2");
        f.step(&mut state, &[SqliteValue::Integer(7)])
            .expect("step 3");

        assert_eq!(
            f.finalize(state).expect("finalize succeeds"),
            SqliteValue::Integer(42)
        );
    }

    #[test]
    fn test_registry_aggregate_type_erased() {
        let mut registry = FunctionRegistry::new();
        registry.register_aggregate(Product);

        // Round-trip through type-erased registry
        let f = registry
            .find_aggregate("product", 1)
            .expect("product found");
        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(6)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(7)]).unwrap();
        assert_eq!(f.finalize(state).unwrap(), SqliteValue::Integer(42));
        assert_eq!(f.name(), "product");
    }

    #[test]
    fn test_aggregate_names_lowercase_dedups_arities() {
        // Variadic PRODUCT: same name as `Product`, different arity.
        struct ProductAnyArity;

        impl AggregateFunction for ProductAnyArity {
            type State = i64;

            fn initial_state(&self) -> Self::State {
                1
            }

            fn step(
                &self,
                state: &mut Self::State,
                args: &[SqliteValue],
            ) -> fsqlite_error::Result<()> {
                for arg in args {
                    *state *= arg.to_integer();
                }
                Ok(())
            }

            fn finalize(&self, state: Self::State) -> fsqlite_error::Result<SqliteValue> {
                Ok(SqliteValue::Integer(state))
            }

            fn num_args(&self) -> i32 {
                -1
            }

            fn name(&self) -> &str {
                "PRODUCT"
            }
        }

        assert!(FunctionRegistry::new().aggregate_names_lowercase().is_empty());

        let mut registry = FunctionRegistry::new();
        registry.register_aggregate(Product);
        registry.register_aggregate(ProductAnyArity);
        // Scalar and window functions must not leak into the aggregate list.
        registry.register_scalar(Double);
        registry.register_window(MovingSum);

        assert_eq!(
            registry.aggregate_names_lowercase(),
            vec!["product".to_owned()]
        );
    }

    #[test]
    fn test_clone_from_arc_leaves_source_untouched() {
        let mut base = FunctionRegistry::new();
        base.register_scalar(Double);
        let shared = std::sync::Arc::new(base);

        let mut extended = FunctionRegistry::clone_from_arc(&shared);
        extended.register_scalar(VariadicConcat);

        // The clone keeps the existing functions and gains the new one...
        assert!(extended.find_scalar("double", 1).is_some());
        assert!(extended.find_scalar("my_func", 3).is_some());
        // ...while the shared source registry is not modified.
        assert!(shared.find_scalar("double", 1).is_some());
        assert!(shared.find_scalar("my_func", 3).is_none());
    }

    #[test]
    fn test_registry_register_and_resolve_window() {
        let mut registry = FunctionRegistry::new();
        let previous = registry.register_window(MovingSum);
        assert!(previous.is_none());
        assert!(registry.contains_window("moving_sum"));
        let f = registry
            .find_window("MOVING_SUM", 1)
            .expect("moving_sum window registered");

        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(10)])
            .expect("step 1");
        f.step(&mut state, &[SqliteValue::Integer(20)])
            .expect("step 2");
        f.step(&mut state, &[SqliteValue::Integer(30)])
            .expect("step 3");
        assert_eq!(f.value(&state).expect("value"), SqliteValue::Integer(60));

        f.inverse(&mut state, &[SqliteValue::Integer(10)])
            .expect("inverse 1");
        f.step(&mut state, &[SqliteValue::Integer(40)])
            .expect("step 4");
        assert_eq!(f.value(&state).expect("value"), SqliteValue::Integer(90));
    }

    #[test]
    fn test_registry_window_type_erased() {
        let mut registry = FunctionRegistry::new();
        registry.register_window(MovingSum);

        let f = registry
            .find_window("moving_sum", 1)
            .expect("moving_sum found");

        // Full lifecycle: initial_state -> step -> inverse -> value -> finalize
        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(100)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(100));

        f.inverse(&mut state, &[SqliteValue::Integer(100)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(0));

        f.step(&mut state, &[SqliteValue::Integer(42)]).unwrap();
        assert_eq!(f.finalize(state).unwrap(), SqliteValue::Integer(42));
    }

    #[test]
    fn test_function_key_equality() {
        let k1 = FunctionKey::new("ABS", 1);
        let k2 = FunctionKey::new("abs", 1);
        let k3 = FunctionKey::new("ABS", 2);

        assert_eq!(k1, k2, "case-insensitive equality");
        assert_ne!(k1, k3, "different num_args");
    }

    // ── E2E: bd-1dc9 ────────────────────────────────────────────────────

    #[test]
    fn test_e2e_custom_collation_in_order_by() {
        use collation::{BinaryCollation, CollationFunction, NoCaseCollation, RtrimCollation};

        // Simulate ORDER BY with a custom reverse-alphabetical collation.
        struct ReverseAlpha;

        impl CollationFunction for ReverseAlpha {
            fn name(&self) -> &str {
                "REVERSE_ALPHA"
            }

            fn compare(&self, left: &[u8], right: &[u8]) -> std::cmp::Ordering {
                // Reverse of BINARY
                right.cmp(left)
            }
        }

        let coll = ReverseAlpha;
        let mut data: Vec<&[u8]> = vec![b"banana", b"apple", b"cherry", b"date"];
        data.sort_by(|a, b| coll.compare(a, b));

        // Reverse alphabetical: date > cherry > banana > apple
        let expected: Vec<&[u8]> = vec![b"date", b"cherry", b"banana", b"apple"];
        assert_eq!(data, expected);
        assert_eq!(coll.name(), "REVERSE_ALPHA");

        // Verify built-in collations are usable as trait objects.
        let collations: Vec<Box<dyn CollationFunction>> = vec![
            Box::new(BinaryCollation),
            Box::new(NoCaseCollation),
            Box::new(RtrimCollation),
            Box::new(ReverseAlpha),
        ];
        assert_eq!(collations.len(), 4);

        // Sort with BINARY: normal alphabetical
        let mut binary_sorted = data.clone();
        binary_sorted.sort_by(|a, b| collations[0].compare(a, b));
        assert_eq!(binary_sorted[0], b"apple");
    }

    #[test]
    fn test_e2e_authorizer_sandboxing() {
        use authorizer::{AuthAction, AuthResult, Authorizer};

        // Authorizer that denies INSERT/UPDATE/DELETE but allows SELECT.
        struct SelectOnlyAuthorizer;

        impl Authorizer for SelectOnlyAuthorizer {
            fn authorize(
                &self,
                action: AuthAction,
                _arg1: Option<&str>,
                arg2: Option<&str>,
                _db_name: Option<&str>,
                _trigger: Option<&str>,
            ) -> AuthResult {
                match action {
                    AuthAction::Select | AuthAction::Read => {
                        // Ignore the "secret" column (replaced with NULL)
                        if action == AuthAction::Read && arg2 == Some("secret") {
                            return AuthResult::Ignore;
                        }
                        AuthResult::Ok
                    }
                    AuthAction::Insert | AuthAction::Update | AuthAction::Delete => {
                        AuthResult::Deny
                    }
                    _ => AuthResult::Ok,
                }
            }
        }

        let auth = SelectOnlyAuthorizer;

        // SELECT is allowed at compile time.
        assert_eq!(
            auth.authorize(AuthAction::Select, None, None, Some("main"), None),
            AuthResult::Ok,
            "SELECT must be allowed"
        );

        // INSERT is denied at compile time.
        assert_eq!(
            auth.authorize(AuthAction::Insert, Some("users"), None, Some("main"), None),
            AuthResult::Deny,
            "INSERT must be denied (compile-time auth error)"
        );

        // UPDATE is denied.
        assert_eq!(
            auth.authorize(
                AuthAction::Update,
                Some("users"),
                Some("email"),
                Some("main"),
                None
            ),
            AuthResult::Deny,
        );

        // DELETE is denied.
        assert_eq!(
            auth.authorize(AuthAction::Delete, Some("users"), None, Some("main"), None),
            AuthResult::Deny,
        );

        // Read on "secret" column returns Ignore (nullify).
        assert_eq!(
            auth.authorize(
                AuthAction::Read,
                Some("users"),
                Some("secret"),
                Some("main"),
                None
            ),
            AuthResult::Ignore,
            "Ignore must nullify column"
        );

        // Read on normal column is allowed.
        assert_eq!(
            auth.authorize(
                AuthAction::Read,
                Some("users"),
                Some("name"),
                Some("main"),
                None
            ),
            AuthResult::Ok,
        );
    }

    #[test]
    fn test_e2e_function_registry_resolution() {
        // Register abs(1 arg) and a variadic version, then test resolution.
        struct Abs1;

        impl ScalarFunction for Abs1 {
            fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
                Ok(SqliteValue::Integer(args[0].to_integer().abs()))
            }

            fn num_args(&self) -> i32 {
                1
            }

            fn name(&self) -> &str {
                "abs"
            }
        }

        struct AbsVariadic;

        impl ScalarFunction for AbsVariadic {
            fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
                // Variadic: return sum of absolute values
                let sum: i64 = args.iter().map(|a| a.to_integer().abs()).sum();
                Ok(SqliteValue::Integer(sum))
            }

            fn num_args(&self) -> i32 {
                -1
            }

            fn name(&self) -> &str {
                "abs"
            }
        }

        let mut registry = FunctionRegistry::new();
        registry.register_scalar(Abs1);
        registry.register_scalar(AbsVariadic);

        // SELECT abs(-5) should use 1-arg version.
        let f = registry.find_scalar("abs", 1).expect("abs(1) found");
        assert_eq!(f.num_args(), 1, "exact 1-arg match");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(-5)]).unwrap(),
            SqliteValue::Integer(5)
        );

        // SELECT abs(-5, -3) should fall through to variadic.
        let f = registry.find_scalar("abs", 2).expect("abs variadic found");
        assert_eq!(f.num_args(), -1, "variadic fallback for 2 args");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(-5), SqliteValue::Integer(-3)])
                .unwrap(),
            SqliteValue::Integer(8)
        );

        // Nonexistent function returns None.
        assert!(registry.find_scalar("nonexistent", 1).is_none());
    }

    /// Exercises a call-recording authorizer.
    ///
    /// NOTE: there is no real prepare/step pipeline here. The "compile-time"
    /// calls are made directly by the test. The "execution" phase only
    /// confirms that nothing else recorded a call in between. A true
    /// prepare-vs-step check belongs in an engine-level test.
    #[test]
    fn test_authorizer_called_at_compile_time() {
        use authorizer::{AuthAction, AuthResult, Authorizer};
        use std::sync::Mutex;

        // Track every authorize call to verify compile-time invocation pattern.
        struct TrackingAuthorizer {
            calls: Mutex<Vec<AuthAction>>,
        }

        impl TrackingAuthorizer {
            fn new() -> Self {
                Self {
                    calls: Mutex::new(Vec::new()),
                }
            }
        }

        impl Authorizer for TrackingAuthorizer {
            fn authorize(
                &self,
                action: AuthAction,
                _arg1: Option<&str>,
                _arg2: Option<&str>,
                _db_name: Option<&str>,
                _trigger: Option<&str>,
            ) -> AuthResult {
                self.calls.lock().unwrap().push(action);
                AuthResult::Ok
            }
        }

        let auth = TrackingAuthorizer::new();

        // Simulate compile-time authorization for:
        // `SELECT name, email FROM users WHERE id = ?`
        //
        // The authorizer is called during prepare(), NOT during step().
        // Expected calls:
        //   1. Select (the statement type)
        //   2. Read(users, name)
        //   3. Read(users, email)
        //   4. Read(users, id)    -- WHERE clause column

        // Phase 1: prepare (compile time) — authorizer is called
        auth.authorize(AuthAction::Select, None, None, Some("main"), None);
        auth.authorize(
            AuthAction::Read,
            Some("users"),
            Some("name"),
            Some("main"),
            None,
        );
        auth.authorize(
            AuthAction::Read,
            Some("users"),
            Some("email"),
            Some("main"),
            None,
        );
        auth.authorize(
            AuthAction::Read,
            Some("users"),
            Some("id"),
            Some("main"),
            None,
        );

        let calls = auth.calls.lock().unwrap();
        assert_eq!(calls.len(), 4, "authorizer called 4 times during prepare");
        assert_eq!(calls[0], AuthAction::Select);
        assert_eq!(calls[1], AuthAction::Read);
        assert_eq!(calls[2], AuthAction::Read);
        assert_eq!(calls[3], AuthAction::Read);
        drop(calls);

        // Phase 2: step (execution) — authorizer is NOT called again
        // (In a real implementation, step() would not invoke authorize.)
        // We simply verify no additional calls were recorded.
        let calls_after = auth.calls.lock().unwrap();
        assert_eq!(
            calls_after.len(),
            4,
            "authorizer must NOT be called during step/execution"
        );
        drop(calls_after);
    }
}