// fsqlite_func/scalar.rs

//! Scalar (row-level) function trait.
//!
//! Scalar functions compute a single output value from zero or more input
//! values. They are stateless across rows: each invocation is independent.
//!
//! This trait is **open** (user-implementable), unlike the sealed pager/btree
//! traits. Extension authors implement `ScalarFunction` to register custom
//! SQL functions.
//!
//! # Send + Sync
//!
//! Scalar functions may be shared across threads via `Arc` for use by
//! concurrent query executors. The trait therefore requires `Send + Sync`,
//! so implementations must be thread-safe.
//!
//! # Cx Exception
//!
//! `invoke` does **not** take `&Cx` because deterministic scalar functions
//! are pure computations (§9 cross-cutting rule: "Pure computation
//! exceptions: deterministic ScalarFunction::invoke without I/O need not
//! take Cx").
21#![allow(clippy::unnecessary_literal_bound)]
22
23use fsqlite_error::Result;
24use fsqlite_types::SqliteValue;
25
26/// A scalar (row-level) SQL function.
27///
28/// Scalar functions are invoked once per row and return a single value.
29/// They are stored in the [`FunctionRegistry`](crate::FunctionRegistry) as
30/// `Arc<dyn ScalarFunction>`.
31///
32/// # Error Handling
33///
34/// - Return [`FrankenError::FunctionError`](fsqlite_error::FrankenError::FunctionError)
35///   for domain errors (e.g. `abs(i64::MIN)`).
36/// - Return [`FrankenError::TooBig`](fsqlite_error::FrankenError::TooBig)
37///   if the result exceeds `SQLITE_MAX_LENGTH`.
38pub trait ScalarFunction: Send + Sync {
39    /// Execute this function on the given arguments.
40    fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue>;
41
42    /// Whether this function is deterministic (same inputs → same output).
43    ///
44    /// Deterministic functions enable constant folding and other query
45    /// planner optimizations. Defaults to `true`.
46    fn is_deterministic(&self) -> bool {
47        true
48    }
49
50    /// The number of arguments this function accepts.
51    ///
52    /// `-1` means variadic (any number of arguments).
53    fn num_args(&self) -> i32;
54
55    /// Minimum accepted argument count for variadic functions.
56    ///
57    /// Fixed-arity functions default to their exact arity. Variadic functions
58    /// default to accepting zero arguments unless an implementation tightens
59    /// the bound to match SQLite's function surface.
60    fn min_args(&self) -> i32 {
61        self.num_args().max(0)
62    }
63
64    /// Maximum accepted argument count, or `None` for unbounded variadic
65    /// functions.
66    fn max_args(&self) -> Option<i32> {
67        (self.num_args() >= 0).then(|| self.num_args())
68    }
69
70    /// Return whether this function accepts `num_args` arguments.
71    fn accepts_arg_count(&self, num_args: i32) -> bool {
72        num_args >= self.min_args() && self.max_args().is_none_or(|max| num_args <= max)
73    }
74
75    /// The function name, used in error messages and EXPLAIN output.
76    fn name(&self) -> &str;
77}
78
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use fsqlite_error::FrankenError;

    use super::*;

    // -- Mock: add_one(x) -> x + 1 --

    /// Fixed-arity (1) mock: adds one to numbers, parses text (falling back
    /// to 0 when parsing fails), passes NULL through, and maps blobs to 1.
    struct AddOne;

    impl ScalarFunction for AddOne {
        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
            match &args[0] {
                SqliteValue::Integer(i) => Ok(SqliteValue::Integer(i + 1)),
                SqliteValue::Float(f) => Ok(SqliteValue::Float(f + 1.0)),
                SqliteValue::Null => Ok(SqliteValue::Null),
                SqliteValue::Text(s) => {
                    // Unparseable text coerces to 0 in this mock.
                    let n: i64 = s.parse().unwrap_or(0);
                    Ok(SqliteValue::Integer(n + 1))
                }
                SqliteValue::Blob(_) => Ok(SqliteValue::Integer(1)),
            }
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "add_one"
        }
    }

    // -- Mock: non-deterministic --

    /// Nullary mock that opts out of the deterministic default.
    struct NonDeterministic;

    impl ScalarFunction for NonDeterministic {
        fn invoke(&self, _args: &[SqliteValue]) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(42))
        }

        fn is_deterministic(&self) -> bool {
            false
        }

        fn num_args(&self) -> i32 {
            0
        }

        fn name(&self) -> &str {
            "random_ish"
        }
    }

    // -- Mock: variadic concat --

    /// Variadic (`-1`) mock that concatenates the text form of every argument.
    struct Concat;

    impl ScalarFunction for Concat {
        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
            let mut result = String::new();
            for arg in args {
                result.push_str(&arg.to_text());
            }
            Ok(SqliteValue::Text(result.into()))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "concat"
        }
    }

    // -- Mock: domain error --

    /// Mock `abs` that reports `i64::MIN` as a domain error instead of overflowing.
    struct SafeAbs;

    impl ScalarFunction for SafeAbs {
        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
            match &args[0] {
                SqliteValue::Integer(i) => {
                    if *i == i64::MIN {
                        return Err(FrankenError::function_error("abs(i64::MIN) would overflow"));
                    }
                    Ok(SqliteValue::Integer(i.abs()))
                }
                _ => Ok(args[0].clone()),
            }
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "safe_abs"
        }
    }

    // -- Mock: too-big error --

    /// Mock that always fails with `TooBig`.
    struct BigResult;

    impl ScalarFunction for BigResult {
        fn invoke(&self, _args: &[SqliteValue]) -> Result<SqliteValue> {
            Err(FrankenError::TooBig)
        }

        fn num_args(&self) -> i32 {
            0
        }

        fn name(&self) -> &str {
            "big_result"
        }
    }

    // -- Tests --

    #[test]
    fn test_scalar_function_invoke_basic() {
        let f = AddOne;
        // Integer
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(41)]).unwrap(),
            SqliteValue::Integer(42)
        );
        // Float
        assert_eq!(
            f.invoke(&[SqliteValue::Float(1.5)]).unwrap(),
            SqliteValue::Float(2.5)
        );
        // Null
        assert!(f.invoke(&[SqliteValue::Null]).unwrap().is_null());
        // Text (numeric coercion)
        assert_eq!(
            f.invoke(&[SqliteValue::Text("99".into())]).unwrap(),
            SqliteValue::Integer(100)
        );
        // Text that does not parse falls back to 0 in this mock.
        assert_eq!(
            f.invoke(&[SqliteValue::Text("abc".into())]).unwrap(),
            SqliteValue::Integer(1)
        );
    }

    #[test]
    fn test_scalar_function_deterministic_flag() {
        let det = AddOne;
        assert!(det.is_deterministic());

        let non_det = NonDeterministic;
        assert!(!non_det.is_deterministic());
    }

    #[test]
    fn test_scalar_function_fixed_arity() {
        let unary = AddOne;
        assert_eq!(unary.num_args(), 1);
        assert_eq!(unary.min_args(), 1);
        assert_eq!(unary.max_args(), Some(1));
        assert!(unary.accepts_arg_count(1));
        assert!(!unary.accepts_arg_count(0));
        assert!(!unary.accepts_arg_count(2));

        let nullary = NonDeterministic;
        assert_eq!(nullary.min_args(), 0);
        assert_eq!(nullary.max_args(), Some(0));
        assert!(nullary.accepts_arg_count(0));
        assert!(!nullary.accepts_arg_count(1));
    }

    #[test]
    fn test_scalar_function_variadic() {
        let f = Concat;
        assert_eq!(f.num_args(), -1);
        assert_eq!(f.min_args(), 0);
        assert_eq!(f.max_args(), None);
        assert!(f.accepts_arg_count(0));
        assert!(f.accepts_arg_count(3));
        // A negative count is never a valid call site.
        assert!(!f.accepts_arg_count(-1));

        // 0 args
        assert_eq!(f.invoke(&[]).unwrap(), SqliteValue::Text("".into()));

        // 1 arg
        assert_eq!(
            f.invoke(&[SqliteValue::Text("hello".into())]).unwrap(),
            SqliteValue::Text("hello".into())
        );

        // many args
        assert_eq!(
            f.invoke(&[
                SqliteValue::Text("a".into()),
                SqliteValue::Text("b".into()),
                SqliteValue::Text("c".into()),
            ])
            .unwrap(),
            SqliteValue::Text("abc".into())
        );
    }

    #[test]
    fn test_scalar_function_error_domain() {
        let f = SafeAbs;
        let err = f.invoke(&[SqliteValue::Integer(i64::MIN)]).unwrap_err();
        assert!(
            matches!(err, FrankenError::FunctionError(ref msg) if msg.contains("overflow")),
            "expected FunctionError, got {err:?}"
        );
    }

    #[test]
    fn test_scalar_function_too_big_error() {
        let f = BigResult;
        let err = f.invoke(&[]).unwrap_err();
        assert!(matches!(err, FrankenError::TooBig));
    }

    #[test]
    fn test_scalar_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<AddOne>();
        // The registry's storage type must itself be shareable across threads.
        assert_send_sync::<Arc<dyn ScalarFunction>>();

        // Can be stored in Arc and invoked from two threads at once; check
        // both results rather than discarding them.
        let f: Arc<dyn ScalarFunction> = Arc::new(AddOne);
        let f2 = Arc::clone(&f);
        let handle = std::thread::spawn(move || f2.invoke(&[SqliteValue::Integer(0)]));
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(1)]).unwrap(),
            SqliteValue::Integer(2)
        );
        assert_eq!(handle.join().unwrap().unwrap(), SqliteValue::Integer(1));
    }
}