Skip to main content

fsqlite_func/
scalar.rs

1//! Scalar (row-level) function trait.
2//!
3//! Scalar functions compute a single output value from zero or more input
4//! values. They are stateless across rows: each invocation is independent.
5//!
6//! This trait is **open** (user-implementable), unlike the sealed pager/btree
7//! traits. Extension authors implement `ScalarFunction` to register custom
8//! SQL functions.
9//!
10//! # Send + Sync
11//!
12//! Scalar functions may be shared across threads via `Arc` for use by
13//! concurrent query executors. Implementations must be thread-safe.
14//!
15//! # Cx Exception
16//!
17//! `invoke` does **not** take `&Cx` because deterministic scalar functions
18//! are pure computations (§9 cross-cutting rule: "Pure computation
19//! exceptions: deterministic ScalarFunction::invoke without I/O need not
20//! take Cx").
21#![allow(clippy::unnecessary_literal_bound)]
22
23use fsqlite_error::Result;
24use fsqlite_types::SqliteValue;
25
26/// A scalar (row-level) SQL function.
27///
28/// Scalar functions are invoked once per row and return a single value.
29/// They are stored in the [`FunctionRegistry`](crate::FunctionRegistry) as
30/// `Arc<dyn ScalarFunction>`.
31///
32/// # Error Handling
33///
34/// - Return [`FrankenError::FunctionError`](fsqlite_error::FrankenError::FunctionError)
35///   for domain errors (e.g. `abs(i64::MIN)`).
36/// - Return [`FrankenError::TooBig`](fsqlite_error::FrankenError::TooBig)
37///   if the result exceeds `SQLITE_MAX_LENGTH`.
38pub trait ScalarFunction: Send + Sync {
39    /// Execute this function on the given arguments.
40    fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue>;
41
42    /// Whether this function is deterministic (same inputs → same output).
43    ///
44    /// Deterministic functions enable constant folding and other query
45    /// planner optimizations. Defaults to `true`.
46    fn is_deterministic(&self) -> bool {
47        true
48    }
49
50    /// The number of arguments this function accepts.
51    ///
52    /// `-1` means variadic (any number of arguments).
53    fn num_args(&self) -> i32;
54
55    /// The function name, used in error messages and EXPLAIN output.
56    fn name(&self) -> &str;
57}
58
59#[cfg(test)]
60mod tests {
61    use std::sync::Arc;
62
63    use fsqlite_error::FrankenError;
64
65    use super::*;
66
67    // -- Mock: add_one(x) -> x + 1 --
68
69    struct AddOne;
70
71    impl ScalarFunction for AddOne {
72        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
73            match &args[0] {
74                SqliteValue::Integer(i) => Ok(SqliteValue::Integer(i + 1)),
75                SqliteValue::Float(f) => Ok(SqliteValue::Float(f + 1.0)),
76                SqliteValue::Null => Ok(SqliteValue::Null),
77                SqliteValue::Text(s) => {
78                    let n: i64 = s.parse().unwrap_or(0);
79                    Ok(SqliteValue::Integer(n + 1))
80                }
81                SqliteValue::Blob(_) => Ok(SqliteValue::Integer(1)),
82            }
83        }
84
85        fn num_args(&self) -> i32 {
86            1
87        }
88
89        fn name(&self) -> &str {
90            "add_one"
91        }
92    }
93
94    // -- Mock: non-deterministic --
95
96    struct NonDeterministic;
97
98    impl ScalarFunction for NonDeterministic {
99        fn invoke(&self, _args: &[SqliteValue]) -> Result<SqliteValue> {
100            Ok(SqliteValue::Integer(42))
101        }
102
103        fn is_deterministic(&self) -> bool {
104            false
105        }
106
107        fn num_args(&self) -> i32 {
108            0
109        }
110
111        fn name(&self) -> &str {
112            "random_ish"
113        }
114    }
115
116    // -- Mock: variadic concat --
117
118    struct Concat;
119
120    impl ScalarFunction for Concat {
121        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
122            let mut result = String::new();
123            for arg in args {
124                result.push_str(&arg.to_text());
125            }
126            Ok(SqliteValue::Text(result.into()))
127        }
128
129        fn num_args(&self) -> i32 {
130            -1
131        }
132
133        fn name(&self) -> &str {
134            "concat"
135        }
136    }
137
138    // -- Mock: domain error --
139
140    struct SafeAbs;
141
142    impl ScalarFunction for SafeAbs {
143        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
144            match &args[0] {
145                SqliteValue::Integer(i) => {
146                    if *i == i64::MIN {
147                        return Err(FrankenError::function_error("abs(i64::MIN) would overflow"));
148                    }
149                    Ok(SqliteValue::Integer(i.abs()))
150                }
151                _ => Ok(args[0].clone()),
152            }
153        }
154
155        fn num_args(&self) -> i32 {
156            1
157        }
158
159        fn name(&self) -> &str {
160            "safe_abs"
161        }
162    }
163
164    // -- Mock: too-big error --
165
166    struct BigResult;
167
168    impl ScalarFunction for BigResult {
169        fn invoke(&self, _args: &[SqliteValue]) -> Result<SqliteValue> {
170            Err(FrankenError::TooBig)
171        }
172
173        fn num_args(&self) -> i32 {
174            0
175        }
176
177        fn name(&self) -> &str {
178            "big_result"
179        }
180    }
181
182    // -- Tests --
183
184    #[test]
185    fn test_scalar_function_invoke_basic() {
186        let f = AddOne;
187        // Integer
188        assert_eq!(
189            f.invoke(&[SqliteValue::Integer(41)]).unwrap(),
190            SqliteValue::Integer(42)
191        );
192        // Float
193        assert_eq!(
194            f.invoke(&[SqliteValue::Float(1.5)]).unwrap(),
195            SqliteValue::Float(2.5)
196        );
197        // Null
198        assert!(f.invoke(&[SqliteValue::Null]).unwrap().is_null());
199        // Text (numeric coercion)
200        assert_eq!(
201            f.invoke(&[SqliteValue::Text("99".into())]).unwrap(),
202            SqliteValue::Integer(100)
203        );
204    }
205
206    #[test]
207    fn test_scalar_function_deterministic_flag() {
208        let det = AddOne;
209        assert!(det.is_deterministic());
210
211        let non_det = NonDeterministic;
212        assert!(!non_det.is_deterministic());
213    }
214
215    #[test]
216    fn test_scalar_function_variadic() {
217        let f = Concat;
218        assert_eq!(f.num_args(), -1);
219
220        // 0 args
221        assert_eq!(f.invoke(&[]).unwrap(), SqliteValue::Text("".into()));
222
223        // 1 arg
224        assert_eq!(
225            f.invoke(&[SqliteValue::Text("hello".into())]).unwrap(),
226            SqliteValue::Text("hello".into())
227        );
228
229        // many args
230        assert_eq!(
231            f.invoke(&[
232                SqliteValue::Text("a".into()),
233                SqliteValue::Text("b".into()),
234                SqliteValue::Text("c".into()),
235            ])
236            .unwrap(),
237            SqliteValue::Text("abc".into())
238        );
239    }
240
241    #[test]
242    fn test_scalar_function_error_domain() {
243        let f = SafeAbs;
244        let err = f.invoke(&[SqliteValue::Integer(i64::MIN)]).unwrap_err();
245        assert!(
246            matches!(err, FrankenError::FunctionError(ref msg) if msg.contains("overflow")),
247            "expected FunctionError, got {err:?}"
248        );
249    }
250
251    #[test]
252    fn test_scalar_function_too_big_error() {
253        let f = BigResult;
254        let err = f.invoke(&[]).unwrap_err();
255        assert!(matches!(err, FrankenError::TooBig));
256    }
257
258    #[test]
259    fn test_scalar_send_sync() {
260        fn assert_send_sync<T: Send + Sync>() {}
261        assert_send_sync::<AddOne>();
262
263        // Can be stored in Arc
264        let f: Arc<dyn ScalarFunction> = Arc::new(AddOne);
265        let f2 = Arc::clone(&f);
266        let handle = std::thread::spawn(move || f2.invoke(&[SqliteValue::Integer(0)]));
267        let _ = f.invoke(&[SqliteValue::Integer(1)]);
268        handle.join().unwrap().unwrap();
269    }
270}