Skip to main content

fsqlite_func/
lib.rs

1//! Built-in SQL function and extension trait surfaces.
2//!
3//! This crate defines open, user-implementable traits for:
4//! - scalar, aggregate, and window functions
5//! - virtual table modules/cursors
6//! - collation callbacks
7//! - authorizer callbacks
8//!
9//! It also provides a small in-memory [`FunctionRegistry`] for registering and
10//! resolving scalar/aggregate/window functions by `(name, num_args)` key with
11//! variadic fallback.
12#![allow(clippy::unnecessary_literal_bound)]
13
14use std::any::Any;
15use std::collections::HashMap;
16use std::sync::Arc;
17use std::sync::atomic::{AtomicU64, Ordering};
18
19use tracing::debug;
20
21// ── Function evaluation metrics (bd-2wt.1) ─────────────────────────────────
22
/// Running count of scalar function invocations, summed over every statement.
static FSQLITE_FUNC_CALLS_TOTAL: AtomicU64 = AtomicU64::new(0);
/// Running sum of function evaluation time, in microseconds.
static FSQLITE_FUNC_EVAL_DURATION_US_TOTAL: AtomicU64 = AtomicU64::new(0);

/// Point-in-time copy of the function evaluation counters.
///
/// The two fields are loaded one after the other, so a snapshot taken while
/// other threads are recording calls is not guaranteed to be mutually
/// consistent.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FuncMetricsSnapshot {
    /// Number of scalar function calls recorded so far.
    pub calls_total: u64,
    /// Total evaluation time recorded so far, in microseconds.
    pub eval_duration_us_total: u64,
}

/// Capture the current values of the function evaluation counters.
#[must_use]
pub fn func_metrics_snapshot() -> FuncMetricsSnapshot {
    let calls_total = FSQLITE_FUNC_CALLS_TOTAL.load(Ordering::Relaxed);
    let eval_duration_us_total = FSQLITE_FUNC_EVAL_DURATION_US_TOTAL.load(Ordering::Relaxed);
    FuncMetricsSnapshot {
        calls_total,
        eval_duration_us_total,
    }
}

/// Zero both function evaluation counters (intended for tests/diagnostics).
pub fn reset_func_metrics() {
    for counter in [&FSQLITE_FUNC_CALLS_TOTAL, &FSQLITE_FUNC_EVAL_DURATION_US_TOTAL] {
        counter.store(0, Ordering::Relaxed);
    }
}

/// Count one function call and add `duration_us` microseconds to the
/// cumulative evaluation time (called from the VDBE engine).
pub fn record_func_call(duration_us: u64) {
    record_func_call_count_only();
    FSQLITE_FUNC_EVAL_DURATION_US_TOTAL.fetch_add(duration_us, Ordering::Relaxed);
}

/// Count one function call without touching the duration total; the fast
/// path for callers that do not time the evaluation.
pub fn record_func_call_count_only() {
    FSQLITE_FUNC_CALLS_TOTAL.fetch_add(1, Ordering::Relaxed);
}
57
58/// Record a function call count only, without timing (fast path).
59pub fn record_func_call_count_only() {
60    FSQLITE_FUNC_CALLS_TOTAL.fetch_add(1, Ordering::Relaxed);
61}
62
63// ── UDF registration metrics (bd-2wt.3) ────────────────────────────────
64
65/// Total number of UDF registrations.
66static FSQLITE_UDF_REGISTERED: AtomicU64 = AtomicU64::new(0);
67
68/// Record a UDF registration event.
69pub fn record_udf_registered() {
70    FSQLITE_UDF_REGISTERED.fetch_add(1, Ordering::Relaxed);
71}
72
73/// Current count of UDF registrations.
74#[must_use]
75pub fn udf_registered_count() -> u64 {
76    FSQLITE_UDF_REGISTERED.load(Ordering::Relaxed)
77}
78
79/// Reset UDF registration counter (tests/diagnostics).
80pub fn reset_udf_metrics() {
81    FSQLITE_UDF_REGISTERED.store(0, Ordering::Relaxed);
82}
83
84pub mod agg_builtins;
85pub mod aggregate;
86pub mod authorizer;
87pub mod builtins;
88pub mod collation;
89pub mod datetime;
90pub mod math;
91pub mod scalar;
92pub mod vtab;
93pub mod window;
94pub mod window_builtins;
95
96pub use agg_builtins::register_aggregate_builtins;
97pub use aggregate::{AggregateAdapter, AggregateFunction};
98pub use authorizer::{AuthAction, AuthResult, Authorizer, AuthorizerAction, AuthorizerDecision};
99pub use builtins::{
100    ChangeTrackingState, get_last_changes, get_last_insert_rowid, get_total_changes,
101    register_builtins, reset_total_changes, set_change_tracking_state, set_last_changes,
102    set_last_insert_rowid,
103};
104pub use collation::{
105    BinaryCollation, CollationAnnotation, CollationFunction, CollationRegistry, CollationSource,
106    NoCaseCollation, RtrimCollation, resolve_collation,
107};
108pub use datetime::register_datetime_builtins;
109pub use math::register_math_builtins;
110pub use scalar::ScalarFunction;
111pub use vtab::{
112    ColumnContext, ConstraintOp, IndexConstraint, IndexConstraintUsage, IndexInfo, IndexOrderBy,
113    VirtualTable, VirtualTableCursor,
114};
115pub use window::{WindowAdapter, WindowFunction};
116pub use window_builtins::register_window_builtins;
117
118/// Type-erased aggregate function object used by the registry.
119pub type ErasedAggregateFunction = dyn AggregateFunction<State = Box<dyn Any + Send>>;
120
121/// Type-erased window function object used by the registry.
122pub type ErasedWindowFunction = dyn WindowFunction<State = Box<dyn Any + Send>>;
123
124/// Composite lookup key for functions: `(UPPERCASE name, num_args)`.
125///
126/// `-1` for `num_args` means variadic (any number of arguments).
127/// Names are stored as uppercase ASCII for case-insensitive matching.
128#[derive(Debug, Clone, Hash, Eq, PartialEq)]
129pub struct FunctionKey {
130    /// Function name, stored as uppercase ASCII.
131    pub name: String,
132    /// Expected argument count, or `-1` for variadic.
133    pub num_args: i32,
134}
135
136impl FunctionKey {
137    /// Create a new function key with the name canonicalized to uppercase.
138    #[must_use]
139    pub fn new(name: &str, num_args: i32) -> Self {
140        Self {
141            name: canonical_name(name),
142            num_args,
143        }
144    }
145}
146
147/// Registry for scalar, aggregate, and window functions, keyed by
148/// `(name, num_args)`.
149///
150/// Lookup strategy (§9.5):
151/// 1. Exact match on `(UPPERCASE_NAME, num_args)`.
152/// 2. Fallback to variadic version `(UPPERCASE_NAME, -1)`.
153/// 3. `None` if neither found (caller should raise "no such function").
154#[derive(Default)]
155pub struct FunctionRegistry {
156    scalars: HashMap<FunctionKey, Arc<dyn ScalarFunction>>,
157    aggregates: HashMap<FunctionKey, Arc<ErasedAggregateFunction>>,
158    windows: HashMap<FunctionKey, Arc<ErasedWindowFunction>>,
159}
160
161impl FunctionRegistry {
162    /// Create an empty registry.
163    #[must_use]
164    pub fn new() -> Self {
165        Self::default()
166    }
167
168    /// Create a mutable clone of a registry from an `Arc` reference.
169    ///
170    /// This is used by the UDF registration API to produce a new registry
171    /// containing the existing functions plus the newly registered UDF.
172    #[must_use]
173    pub fn clone_from_arc(arc: &Arc<Self>) -> Self {
174        Self {
175            scalars: arc.scalars.clone(),
176            aggregates: arc.aggregates.clone(),
177            windows: arc.windows.clone(),
178        }
179    }
180
181    /// Register a scalar function, keyed by `(name, num_args)`.
182    ///
183    /// Overwrites any existing function with the same key. Returns the
184    /// previous function if one existed.
185    pub fn register_scalar<F>(&mut self, function: F) -> Option<Arc<dyn ScalarFunction>>
186    where
187        F: ScalarFunction + 'static,
188    {
189        let key = FunctionKey::new(function.name(), function.num_args());
190        self.scalars.insert(key, Arc::new(function))
191    }
192
193    /// Register an aggregate function using the type-erased adapter.
194    ///
195    /// Overwrites any existing function with the same `(name, num_args)` key.
196    pub fn register_aggregate<F>(&mut self, function: F) -> Option<Arc<ErasedAggregateFunction>>
197    where
198        F: AggregateFunction + 'static,
199        F::State: 'static,
200    {
201        let key = FunctionKey::new(function.name(), function.num_args());
202        self.aggregates
203            .insert(key, Arc::new(AggregateAdapter::new(function)))
204    }
205
206    /// Register a window function using the type-erased adapter.
207    ///
208    /// Overwrites any existing function with the same `(name, num_args)` key.
209    pub fn register_window<F>(&mut self, function: F) -> Option<Arc<ErasedWindowFunction>>
210    where
211        F: WindowFunction + 'static,
212        F::State: 'static,
213    {
214        let key = FunctionKey::new(function.name(), function.num_args());
215        self.windows
216            .insert(key, Arc::new(WindowAdapter::new(function)))
217    }
218
219    /// Look up a scalar function by `(name, num_args)`.
220    ///
221    /// Tries exact match first, then falls back to the variadic version
222    /// `(name, -1)` if no exact match exists.
223    #[must_use]
224    pub fn find_scalar(&self, name: &str, num_args: i32) -> Option<Arc<dyn ScalarFunction>> {
225        let canon = canonical_name(name);
226        self.find_scalar_precanonical(&canon, num_args)
227    }
228
229    /// Look up a scalar function by already-uppercased name (avoids allocation).
230    ///
231    /// Used by the VDBE engine where `P4::FuncName` values are already
232    /// canonicalized by codegen.
233    #[must_use]
234    pub fn find_scalar_precanonical(
235        &self,
236        canonical: &str,
237        num_args: i32,
238    ) -> Option<Arc<dyn ScalarFunction>> {
239        let exact = FunctionKey {
240            name: canonical.to_owned(),
241            num_args,
242        };
243        if let Some(f) = self.scalars.get(&exact) {
244            debug!(name = %canonical, arity = num_args, kind = "scalar", hit = "exact", "registry lookup");
245            return Some(Arc::clone(f));
246        }
247        // Variadic fallback
248        let variadic = FunctionKey {
249            name: canonical.to_owned(),
250            num_args: -1,
251        };
252        let result = self.scalars.get(&variadic).map(Arc::clone);
253        debug!(
254            name = %canonical,
255            arity = num_args,
256            kind = "scalar",
257            hit = if result.is_some() { "variadic" } else { "miss" },
258            "registry lookup"
259        );
260        result
261    }
262
263    /// Look up an aggregate function by `(name, num_args)`.
264    ///
265    /// Tries exact match first, then falls back to variadic `(name, -1)`.
266    #[must_use]
267    pub fn find_aggregate(
268        &self,
269        name: &str,
270        num_args: i32,
271    ) -> Option<Arc<ErasedAggregateFunction>> {
272        let canon = canonical_name(name);
273        self.find_aggregate_precanonical(&canon, num_args)
274    }
275
276    /// Look up an aggregate function by already-uppercased name (avoids allocation).
277    #[must_use]
278    pub fn find_aggregate_precanonical(
279        &self,
280        canonical: &str,
281        num_args: i32,
282    ) -> Option<Arc<ErasedAggregateFunction>> {
283        let exact = FunctionKey {
284            name: canonical.to_owned(),
285            num_args,
286        };
287        if let Some(f) = self.aggregates.get(&exact) {
288            debug!(name = %canonical, arity = num_args, kind = "aggregate", hit = "exact", "registry lookup");
289            return Some(Arc::clone(f));
290        }
291        let variadic = FunctionKey {
292            name: canonical.to_owned(),
293            num_args: -1,
294        };
295        let result = self.aggregates.get(&variadic).map(Arc::clone);
296        debug!(
297            name = %canonical,
298            arity = num_args,
299            kind = "aggregate",
300            hit = if result.is_some() { "variadic" } else { "miss" },
301            "registry lookup"
302        );
303        result
304    }
305
306    /// Look up a window function by `(name, num_args)`.
307    ///
308    /// Tries exact match first, then falls back to variadic `(name, -1)`.
309    #[must_use]
310    pub fn find_window(&self, name: &str, num_args: i32) -> Option<Arc<ErasedWindowFunction>> {
311        let canon = canonical_name(name);
312        let exact = FunctionKey {
313            name: canon.clone(),
314            num_args,
315        };
316        if let Some(f) = self.windows.get(&exact) {
317            debug!(name = %canon, arity = num_args, kind = "window", hit = "exact", "registry lookup");
318            return Some(Arc::clone(f));
319        }
320        let variadic = FunctionKey {
321            name: canon.clone(),
322            num_args: -1,
323        };
324        let result = self.windows.get(&variadic).map(Arc::clone);
325        debug!(
326            name = %canon,
327            arity = num_args,
328            kind = "window",
329            hit = if result.is_some() { "variadic" } else { "miss" },
330            "registry lookup"
331        );
332        result
333    }
334
335    /// Whether the registry contains any scalar function with this name
336    /// (any arg count).
337    #[must_use]
338    pub fn contains_scalar(&self, name: &str) -> bool {
339        let canon = canonical_name(name);
340        self.scalars.keys().any(|k| k.name == canon)
341    }
342
343    /// Whether the registry contains any aggregate function with this name
344    /// (any arg count).
345    #[must_use]
346    pub fn contains_aggregate(&self, name: &str) -> bool {
347        let canon = canonical_name(name);
348        self.aggregates.keys().any(|k| k.name == canon)
349    }
350
351    /// Whether the registry contains any window function with this name
352    /// (any arg count).
353    #[must_use]
354    pub fn contains_window(&self, name: &str) -> bool {
355        let canon = canonical_name(name);
356        self.windows.keys().any(|k| k.name == canon)
357    }
358
359    /// Return deduplicated lowercase names of all registered aggregate functions.
360    ///
361    /// Used by the codegen thread-local to recognize custom aggregate UDFs.
362    #[must_use]
363    pub fn aggregate_names_lowercase(&self) -> Vec<String> {
364        let mut names: Vec<String> = self
365            .aggregates
366            .keys()
367            .map(|k| k.name.to_ascii_lowercase())
368            .collect();
369        names.sort();
370        names.dedup();
371        names
372    }
373}
374
/// Normalize a function name for use in registry keys: strip surrounding
/// whitespace, then uppercase ASCII letters (non-ASCII characters are left
/// unchanged).
fn canonical_name(name: &str) -> String {
    let mut canon = name.trim().to_owned();
    canon.make_ascii_uppercase();
    canon
}
378
#[cfg(test)]
mod tests {
    use fsqlite_types::SqliteValue;

    use super::*;

    // -- Mock: double(x) -> x * 2, fixed 1-arg --

    struct Double;

    impl ScalarFunction for Double {
        fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(args[0].to_integer() * 2))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "double"
        }
    }

    // -- Mock: variadic concat --

    struct VariadicConcat;

    impl ScalarFunction for VariadicConcat {
        fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
            let mut out = String::new();
            for a in args {
                out.push_str(&a.to_text());
            }
            Ok(SqliteValue::Text(out.into()))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "my_func"
        }
    }

    // -- Mock: fixed 2-arg version of same name --

    struct TwoArgFunc;

    impl ScalarFunction for TwoArgFunc {
        fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(
                args[0].to_integer() + args[1].to_integer(),
            ))
        }

        fn num_args(&self) -> i32 {
            2
        }

        fn name(&self) -> &str {
            "my_func"
        }
    }

    struct Product;

    impl AggregateFunction for Product {
        type State = i64;

        fn initial_state(&self) -> Self::State {
            1
        }

        fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> fsqlite_error::Result<()> {
            *state *= args[0].to_integer();
            Ok(())
        }

        fn finalize(&self, state: Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "product"
        }
    }

    // -- Mock: variadic aggregate sharing the name "product" --

    struct VariadicProduct;

    impl AggregateFunction for VariadicProduct {
        type State = i64;

        fn initial_state(&self) -> Self::State {
            1
        }

        fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> fsqlite_error::Result<()> {
            for a in args {
                *state *= a.to_integer();
            }
            Ok(())
        }

        fn finalize(&self, state: Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "product"
        }
    }

    struct MovingSum;

    impl WindowFunction for MovingSum {
        type State = i64;

        fn initial_state(&self) -> Self::State {
            0
        }

        fn step(&self, state: &mut Self::State, args: &[SqliteValue]) -> fsqlite_error::Result<()> {
            *state += args[0].to_integer();
            Ok(())
        }

        fn inverse(
            &self,
            state: &mut Self::State,
            args: &[SqliteValue],
        ) -> fsqlite_error::Result<()> {
            *state -= args[0].to_integer();
            Ok(())
        }

        fn value(&self, state: &Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(*state))
        }

        fn finalize(&self, state: Self::State) -> fsqlite_error::Result<SqliteValue> {
            Ok(SqliteValue::Integer(state))
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "moving_sum"
        }
    }

    #[test]
    fn test_registry_register_scalar() {
        let mut registry = FunctionRegistry::new();
        let previous = registry.register_scalar(Double);
        assert!(previous.is_none());
        assert!(registry.contains_scalar("double"));
        assert!(registry.contains_scalar("DOUBLE"));
        let f = registry
            .find_scalar(" Double ", 1)
            .expect("double registered");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(21)])
                .expect("invoke succeeds"),
            SqliteValue::Integer(42)
        );
    }

    #[test]
    fn test_registry_case_insensitive_lookup() {
        let mut registry = FunctionRegistry::new();
        registry.register_scalar(Double);

        // Register as "double", look up as "DOUBLE", "Double", " double "
        assert!(registry.find_scalar("DOUBLE", 1).is_some());
        assert!(registry.find_scalar("Double", 1).is_some());
        assert!(registry.find_scalar(" double ", 1).is_some());
    }

    #[test]
    fn test_registry_overwrite() {
        let mut registry = FunctionRegistry::new();

        // Register first version
        let prev = registry.register_scalar(Double);
        assert!(prev.is_none());

        // Register second version with same (name, num_args) — overwrites
        let prev = registry.register_scalar(Double);
        assert!(prev.is_some());

        // Still works
        let f = registry.find_scalar("double", 1).unwrap();
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(5)]).unwrap(),
            SqliteValue::Integer(10)
        );
    }

    #[test]
    fn test_registry_variadic_fallback() {
        let mut registry = FunctionRegistry::new();

        // Register only the variadic version (num_args = -1)
        registry.register_scalar(VariadicConcat);

        // Look up with specific arg count — no exact match, falls back to variadic
        let f = registry
            .find_scalar("my_func", 3)
            .expect("variadic fallback");
        assert_eq!(
            f.invoke(&[
                SqliteValue::Text("a".into()),
                SqliteValue::Text("b".into()),
                SqliteValue::Text("c".into()),
            ])
            .unwrap(),
            SqliteValue::Text("abc".into())
        );
    }

    #[test]
    fn test_registry_exact_match_over_variadic() {
        let mut registry = FunctionRegistry::new();

        // Register both variadic (num_args=-1) and exact 2-arg version
        registry.register_scalar(VariadicConcat);
        registry.register_scalar(TwoArgFunc);

        // Look up with num_args=2 — exact match wins over variadic
        let f = registry
            .find_scalar("my_func", 2)
            .expect("exact match found");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(10), SqliteValue::Integer(32)])
                .unwrap(),
            SqliteValue::Integer(42)
        );

        // Look up with num_args=5 — no exact match, falls back to variadic
        let f = registry
            .find_scalar("my_func", 5)
            .expect("variadic fallback");
        assert_eq!(f.num_args(), -1);
    }

    #[test]
    fn test_registry_not_found_returns_none() {
        let registry = FunctionRegistry::new();
        assert!(registry.find_scalar("nonexistent", 1).is_none());
        assert!(registry.find_aggregate("nonexistent", 1).is_none());
        assert!(registry.find_window("nonexistent", 1).is_none());
    }

    #[test]
    fn test_registry_wrong_arity_without_variadic_misses() {
        let mut registry = FunctionRegistry::new();
        registry.register_scalar(Double);

        // Only the 1-arg version exists and there is no variadic fallback.
        assert!(registry.find_scalar("double", 2).is_none());
        assert!(registry.find_scalar("double", 0).is_none());
    }

    #[test]
    fn test_registry_kinds_are_separate_namespaces() {
        let mut registry = FunctionRegistry::new();
        registry.register_scalar(Double);
        registry.register_aggregate(Product);

        // A scalar is not visible as an aggregate/window and vice versa.
        assert!(!registry.contains_aggregate("double"));
        assert!(!registry.contains_window("double"));
        assert!(registry.find_aggregate("double", 1).is_none());
        assert!(registry.find_window("double", 1).is_none());
        assert!(!registry.contains_scalar("product"));
        assert!(registry.find_scalar("product", 1).is_none());
    }

    #[test]
    fn test_registry_precanonical_lookups() {
        let mut registry = FunctionRegistry::new();
        registry.register_scalar(Double);
        registry.register_scalar(VariadicConcat);
        registry.register_aggregate(Product);

        // Pre-canonical lookups accept the already-uppercased name.
        let f = registry
            .find_scalar_precanonical("DOUBLE", 1)
            .expect("exact precanonical hit");
        assert_eq!(f.num_args(), 1);
        let f = registry
            .find_scalar_precanonical("MY_FUNC", 4)
            .expect("variadic precanonical hit");
        assert_eq!(f.num_args(), -1);
        assert!(registry.find_aggregate_precanonical("PRODUCT", 1).is_some());

        // They skip canonicalization, so a non-canonical name misses.
        assert!(registry.find_scalar_precanonical("double", 1).is_none());
        assert!(registry.find_aggregate_precanonical("product", 1).is_none());
    }

    #[test]
    fn test_aggregate_names_lowercase_dedups() {
        let mut registry = FunctionRegistry::new();
        assert!(registry.aggregate_names_lowercase().is_empty());

        // Two arities of the same name must yield a single entry.
        registry.register_aggregate(Product);
        registry.register_aggregate(VariadicProduct);
        assert_eq!(
            registry.aggregate_names_lowercase(),
            vec!["product".to_owned()]
        );
    }

    #[test]
    fn test_clone_from_arc_is_independent() {
        let mut base = FunctionRegistry::new();
        base.register_scalar(Double);
        let shared = std::sync::Arc::new(base);

        let mut extended = FunctionRegistry::clone_from_arc(&shared);
        extended.register_scalar(TwoArgFunc);

        // The clone sees the original functions plus the new one...
        assert!(extended.find_scalar("double", 1).is_some());
        assert!(extended.find_scalar("my_func", 2).is_some());
        // ...while the source registry is left untouched.
        assert!(shared.find_scalar("double", 1).is_some());
        assert!(shared.find_scalar("my_func", 2).is_none());
    }

    #[test]
    fn test_registry_register_and_resolve_aggregate() {
        let mut registry = FunctionRegistry::new();
        let previous = registry.register_aggregate(Product);
        assert!(previous.is_none());
        assert!(registry.contains_aggregate("product"));
        let f = registry
            .find_aggregate("PRODUCT", 1)
            .expect("product aggregate registered");

        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(2)])
            .expect("step 1");
        f.step(&mut state, &[SqliteValue::Integer(3)])
            .expect("step 2");
        f.step(&mut state, &[SqliteValue::Integer(7)])
            .expect("step 3");

        assert_eq!(
            f.finalize(state).expect("finalize succeeds"),
            SqliteValue::Integer(42)
        );
    }

    #[test]
    fn test_registry_aggregate_type_erased() {
        let mut registry = FunctionRegistry::new();
        registry.register_aggregate(Product);

        // Round-trip through type-erased registry
        let f = registry
            .find_aggregate("product", 1)
            .expect("product found");
        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(6)]).unwrap();
        f.step(&mut state, &[SqliteValue::Integer(7)]).unwrap();
        assert_eq!(f.finalize(state).unwrap(), SqliteValue::Integer(42));
        assert_eq!(f.name(), "product");
    }

    #[test]
    fn test_registry_register_and_resolve_window() {
        let mut registry = FunctionRegistry::new();
        let previous = registry.register_window(MovingSum);
        assert!(previous.is_none());
        assert!(registry.contains_window("moving_sum"));
        let f = registry
            .find_window("MOVING_SUM", 1)
            .expect("moving_sum window registered");

        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(10)])
            .expect("step 1");
        f.step(&mut state, &[SqliteValue::Integer(20)])
            .expect("step 2");
        f.step(&mut state, &[SqliteValue::Integer(30)])
            .expect("step 3");
        assert_eq!(f.value(&state).expect("value"), SqliteValue::Integer(60));

        f.inverse(&mut state, &[SqliteValue::Integer(10)])
            .expect("inverse 1");
        f.step(&mut state, &[SqliteValue::Integer(40)])
            .expect("step 4");
        assert_eq!(f.value(&state).expect("value"), SqliteValue::Integer(90));
    }

    #[test]
    fn test_registry_window_type_erased() {
        let mut registry = FunctionRegistry::new();
        registry.register_window(MovingSum);

        let f = registry
            .find_window("moving_sum", 1)
            .expect("moving_sum found");

        // Full lifecycle: initial_state -> step -> inverse -> value -> finalize
        let mut state = f.initial_state();
        f.step(&mut state, &[SqliteValue::Integer(100)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(100));

        f.inverse(&mut state, &[SqliteValue::Integer(100)]).unwrap();
        assert_eq!(f.value(&state).unwrap(), SqliteValue::Integer(0));

        f.step(&mut state, &[SqliteValue::Integer(42)]).unwrap();
        assert_eq!(f.finalize(state).unwrap(), SqliteValue::Integer(42));
    }

    #[test]
    fn test_function_key_equality() {
        let k1 = FunctionKey::new("ABS", 1);
        let k2 = FunctionKey::new("abs", 1);
        let k3 = FunctionKey::new("ABS", 2);

        assert_eq!(k1, k2, "case-insensitive equality");
        assert_ne!(k1, k3, "different num_args");
    }

    // ── E2E: bd-1dc9 ────────────────────────────────────────────────────

    #[test]
    fn test_e2e_custom_collation_in_order_by() {
        use collation::{BinaryCollation, CollationFunction, NoCaseCollation, RtrimCollation};

        // Simulate ORDER BY with a custom reverse-alphabetical collation.
        struct ReverseAlpha;

        impl CollationFunction for ReverseAlpha {
            fn name(&self) -> &str {
                "REVERSE_ALPHA"
            }

            fn compare(&self, left: &[u8], right: &[u8]) -> std::cmp::Ordering {
                // Reverse of BINARY
                right.cmp(left)
            }
        }

        let coll = ReverseAlpha;
        let mut data: Vec<&[u8]> = vec![b"banana", b"apple", b"cherry", b"date"];
        data.sort_by(|a, b| coll.compare(a, b));

        // Reverse alphabetical: date > cherry > banana > apple
        let expected: Vec<&[u8]> = vec![b"date", b"cherry", b"banana", b"apple"];
        assert_eq!(data, expected);
        assert_eq!(coll.name(), "REVERSE_ALPHA");

        // Verify built-in collations are usable as trait objects.
        let collations: Vec<Box<dyn CollationFunction>> = vec![
            Box::new(BinaryCollation),
            Box::new(NoCaseCollation),
            Box::new(RtrimCollation),
            Box::new(ReverseAlpha),
        ];
        assert_eq!(collations.len(), 4);

        // Sort with BINARY: normal alphabetical
        let mut binary_sorted = data.clone();
        binary_sorted.sort_by(|a, b| collations[0].compare(a, b));
        assert_eq!(binary_sorted[0], b"apple");
    }

    #[test]
    fn test_e2e_authorizer_sandboxing() {
        use authorizer::{AuthAction, AuthResult, Authorizer};

        // Authorizer that denies INSERT/UPDATE/DELETE but allows SELECT.
        struct SelectOnlyAuthorizer;

        impl Authorizer for SelectOnlyAuthorizer {
            fn authorize(
                &self,
                action: AuthAction,
                _arg1: Option<&str>,
                arg2: Option<&str>,
                _db_name: Option<&str>,
                _trigger: Option<&str>,
            ) -> AuthResult {
                match action {
                    AuthAction::Select | AuthAction::Read => {
                        // Ignore the "secret" column (replaced with NULL)
                        if action == AuthAction::Read && arg2 == Some("secret") {
                            return AuthResult::Ignore;
                        }
                        AuthResult::Ok
                    }
                    AuthAction::Insert | AuthAction::Update | AuthAction::Delete => {
                        AuthResult::Deny
                    }
                    _ => AuthResult::Ok,
                }
            }
        }

        let auth = SelectOnlyAuthorizer;

        // SELECT is allowed at compile time.
        assert_eq!(
            auth.authorize(AuthAction::Select, None, None, Some("main"), None),
            AuthResult::Ok,
            "SELECT must be allowed"
        );

        // INSERT is denied at compile time.
        assert_eq!(
            auth.authorize(AuthAction::Insert, Some("users"), None, Some("main"), None),
            AuthResult::Deny,
            "INSERT must be denied (compile-time auth error)"
        );

        // UPDATE is denied.
        assert_eq!(
            auth.authorize(
                AuthAction::Update,
                Some("users"),
                Some("email"),
                Some("main"),
                None
            ),
            AuthResult::Deny,
        );

        // DELETE is denied.
        assert_eq!(
            auth.authorize(AuthAction::Delete, Some("users"), None, Some("main"), None),
            AuthResult::Deny,
        );

        // Read on "secret" column returns Ignore (nullify).
        assert_eq!(
            auth.authorize(
                AuthAction::Read,
                Some("users"),
                Some("secret"),
                Some("main"),
                None
            ),
            AuthResult::Ignore,
            "Ignore must nullify column"
        );

        // Read on normal column is allowed.
        assert_eq!(
            auth.authorize(
                AuthAction::Read,
                Some("users"),
                Some("name"),
                Some("main"),
                None
            ),
            AuthResult::Ok,
        );
    }

    #[test]
    fn test_e2e_function_registry_resolution() {
        // Register abs(1 arg) and a variadic version, then test resolution.
        struct Abs1;

        impl ScalarFunction for Abs1 {
            fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
                Ok(SqliteValue::Integer(args[0].to_integer().abs()))
            }

            fn num_args(&self) -> i32 {
                1
            }

            fn name(&self) -> &str {
                "abs"
            }
        }

        struct AbsVariadic;

        impl ScalarFunction for AbsVariadic {
            fn invoke(&self, args: &[SqliteValue]) -> fsqlite_error::Result<SqliteValue> {
                // Variadic: return sum of absolute values
                let sum: i64 = args.iter().map(|a| a.to_integer().abs()).sum();
                Ok(SqliteValue::Integer(sum))
            }

            fn num_args(&self) -> i32 {
                -1
            }

            fn name(&self) -> &str {
                "abs"
            }
        }

        let mut registry = FunctionRegistry::new();
        registry.register_scalar(Abs1);
        registry.register_scalar(AbsVariadic);

        // SELECT abs(-5) should use 1-arg version.
        let f = registry.find_scalar("abs", 1).expect("abs(1) found");
        assert_eq!(f.num_args(), 1, "exact 1-arg match");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(-5)]).unwrap(),
            SqliteValue::Integer(5)
        );

        // SELECT abs(-5, -3) should fall through to variadic.
        let f = registry.find_scalar("abs", 2).expect("abs variadic found");
        assert_eq!(f.num_args(), -1, "variadic fallback for 2 args");
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(-5), SqliteValue::Integer(-3)])
                .unwrap(),
            SqliteValue::Integer(8)
        );

        // Nonexistent function returns None.
        assert!(registry.find_scalar("nonexistent", 1).is_none());
    }

    // NOTE: this test drives the authorizer by hand rather than through a real
    // prepare()/step(); it documents the expected call pattern on a mock.
    #[test]
    fn test_authorizer_called_at_compile_time() {
        use authorizer::{AuthAction, AuthResult, Authorizer};
        use std::sync::Mutex;

        // Track every authorize call to verify compile-time invocation pattern.
        struct TrackingAuthorizer {
            calls: Mutex<Vec<AuthAction>>,
        }

        impl TrackingAuthorizer {
            fn new() -> Self {
                Self {
                    calls: Mutex::new(Vec::new()),
                }
            }
        }

        impl Authorizer for TrackingAuthorizer {
            fn authorize(
                &self,
                action: AuthAction,
                _arg1: Option<&str>,
                _arg2: Option<&str>,
                _db_name: Option<&str>,
                _trigger: Option<&str>,
            ) -> AuthResult {
                self.calls.lock().unwrap().push(action);
                AuthResult::Ok
            }
        }

        let auth = TrackingAuthorizer::new();

        // Simulate compile-time authorization for:
        // `SELECT name, email FROM users WHERE id = ?`
        //
        // The authorizer is called during prepare(), NOT during step().
        // Expected calls:
        //   1. Select (the statement type)
        //   2. Read(users, name)
        //   3. Read(users, email)
        //   4. Read(users, id)    -- WHERE clause column

        // Phase 1: prepare (compile time) — authorizer is called
        auth.authorize(AuthAction::Select, None, None, Some("main"), None);
        auth.authorize(
            AuthAction::Read,
            Some("users"),
            Some("name"),
            Some("main"),
            None,
        );
        auth.authorize(
            AuthAction::Read,
            Some("users"),
            Some("email"),
            Some("main"),
            None,
        );
        auth.authorize(
            AuthAction::Read,
            Some("users"),
            Some("id"),
            Some("main"),
            None,
        );

        let calls = auth.calls.lock().unwrap();
        assert_eq!(calls.len(), 4, "authorizer called 4 times during prepare");
        assert_eq!(calls[0], AuthAction::Select);
        assert_eq!(calls[1], AuthAction::Read);
        assert_eq!(calls[2], AuthAction::Read);
        assert_eq!(calls[3], AuthAction::Read);
        drop(calls);

        // Phase 2: step (execution) — authorizer is NOT called again
        // (In a real implementation, step() would not invoke authorize.)
        // We simply verify no additional calls were recorded.
        let calls_after = auth.calls.lock().unwrap();
        assert_eq!(
            calls_after.len(),
            4,
            "authorizer must NOT be called during step/execution"
        );
        drop(calls_after);
    }
}