1#![allow(clippy::unnecessary_literal_bound)]
22
23use fsqlite_error::Result;
24use fsqlite_types::SqliteValue;
25
26pub trait ScalarFunction: Send + Sync {
39 fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue>;
41
42 fn is_deterministic(&self) -> bool {
47 true
48 }
49
50 fn num_args(&self) -> i32;
54
55 fn min_args(&self) -> i32 {
61 self.num_args().max(0)
62 }
63
64 fn max_args(&self) -> Option<i32> {
67 (self.num_args() >= 0).then(|| self.num_args())
68 }
69
70 fn accepts_arg_count(&self, num_args: i32) -> bool {
72 num_args >= self.min_args() && self.max_args().is_none_or(|max| num_args <= max)
73 }
74
75 fn name(&self) -> &str;
77}
78
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use fsqlite_error::FrankenError;

    use super::*;

    /// Unary test function that adds one to its argument.
    ///
    /// Integers and floats are incremented and `NULL` passes through. Text is
    /// parsed as an integer, with unparseable text counting as `0`. Blobs map to `1`.
    struct AddOne;

    impl ScalarFunction for AddOne {
        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
            match &args[0] {
                SqliteValue::Integer(i) => Ok(SqliteValue::Integer(i + 1)),
                SqliteValue::Float(f) => Ok(SqliteValue::Float(f + 1.0)),
                SqliteValue::Null => Ok(SqliteValue::Null),
                SqliteValue::Text(s) => {
                    let n: i64 = s.parse().unwrap_or(0);
                    Ok(SqliteValue::Integer(n + 1))
                }
                SqliteValue::Blob(_) => Ok(SqliteValue::Integer(1)),
            }
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "add_one"
        }
    }

    /// Nullary test function that overrides `is_deterministic` to return `false`.
    struct NonDeterministic;

    impl ScalarFunction for NonDeterministic {
        fn invoke(&self, _args: &[SqliteValue]) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(42))
        }

        fn is_deterministic(&self) -> bool {
            false
        }

        fn num_args(&self) -> i32 {
            0
        }

        fn name(&self) -> &str {
            "random_ish"
        }
    }

    /// Variadic (`num_args() == -1`) test function that concatenates the text
    /// form of every argument.
    struct Concat;

    impl ScalarFunction for Concat {
        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
            let mut result = String::new();
            for arg in args {
                result.push_str(&arg.to_text());
            }
            Ok(SqliteValue::Text(result.into()))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "concat"
        }
    }

    /// Unary `abs` that returns a function error instead of overflowing on
    /// `i64::MIN`. Non-integer arguments are returned unchanged.
    struct SafeAbs;

    impl ScalarFunction for SafeAbs {
        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
            match &args[0] {
                SqliteValue::Integer(i) => {
                    if *i == i64::MIN {
                        return Err(FrankenError::function_error("abs(i64::MIN) would overflow"));
                    }
                    Ok(SqliteValue::Integer(i.abs()))
                }
                _ => Ok(args[0].clone()),
            }
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "safe_abs"
        }
    }

    /// Nullary test function that always fails with `FrankenError::TooBig`.
    struct BigResult;

    impl ScalarFunction for BigResult {
        fn invoke(&self, _args: &[SqliteValue]) -> Result<SqliteValue> {
            Err(FrankenError::TooBig)
        }

        fn num_args(&self) -> i32 {
            0
        }

        fn name(&self) -> &str {
            "big_result"
        }
    }

    #[test]
    fn test_scalar_function_invoke_basic() {
        let f = AddOne;
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(41)]).unwrap(),
            SqliteValue::Integer(42)
        );
        assert_eq!(
            f.invoke(&[SqliteValue::Float(1.5)]).unwrap(),
            SqliteValue::Float(2.5)
        );
        assert!(f.invoke(&[SqliteValue::Null]).unwrap().is_null());
        assert_eq!(
            f.invoke(&[SqliteValue::Text("99".into())]).unwrap(),
            SqliteValue::Integer(100)
        );
        // Unparseable text falls back to 0 before incrementing.
        assert_eq!(
            f.invoke(&[SqliteValue::Text("abc".into())]).unwrap(),
            SqliteValue::Integer(1)
        );
    }

    #[test]
    fn test_scalar_function_deterministic_flag() {
        let det = AddOne;
        assert!(det.is_deterministic());

        let non_det = NonDeterministic;
        assert!(!non_det.is_deterministic());
    }

    #[test]
    fn test_scalar_function_fixed_arity() {
        // A fixed arity is both the minimum and the maximum.
        let unary = AddOne;
        assert_eq!(unary.min_args(), 1);
        assert_eq!(unary.max_args(), Some(1));
        assert!(unary.accepts_arg_count(1));
        assert!(!unary.accepts_arg_count(0));
        assert!(!unary.accepts_arg_count(2));

        let nullary = NonDeterministic;
        assert_eq!(nullary.min_args(), 0);
        assert_eq!(nullary.max_args(), Some(0));
        assert!(nullary.accepts_arg_count(0));
        assert!(!nullary.accepts_arg_count(1));
    }

    #[test]
    fn test_scalar_function_variadic() {
        let f = Concat;
        assert_eq!(f.num_args(), -1);
        assert_eq!(f.min_args(), 0);
        assert_eq!(f.max_args(), None);
        assert!(f.accepts_arg_count(0));
        assert!(f.accepts_arg_count(3));
        // A negative count is never a valid call, even for a variadic function.
        assert!(!f.accepts_arg_count(-1));

        assert_eq!(f.invoke(&[]).unwrap(), SqliteValue::Text("".into()));

        assert_eq!(
            f.invoke(&[SqliteValue::Text("hello".into())]).unwrap(),
            SqliteValue::Text("hello".into())
        );

        assert_eq!(
            f.invoke(&[
                SqliteValue::Text("a".into()),
                SqliteValue::Text("b".into()),
                SqliteValue::Text("c".into()),
            ])
            .unwrap(),
            SqliteValue::Text("abc".into())
        );
    }

    #[test]
    fn test_scalar_function_error_domain() {
        let f = SafeAbs;
        // The non-overflowing path still succeeds.
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(-5)]).unwrap(),
            SqliteValue::Integer(5)
        );

        let err = f.invoke(&[SqliteValue::Integer(i64::MIN)]).unwrap_err();
        assert!(
            matches!(err, FrankenError::FunctionError(ref msg) if msg.contains("overflow")),
            "expected FunctionError, got {err:?}"
        );
    }

    #[test]
    fn test_scalar_function_too_big_error() {
        let f = BigResult;
        let err = f.invoke(&[]).unwrap_err();
        assert!(matches!(err, FrankenError::TooBig));
    }

    #[test]
    fn test_scalar_send_sync() {
        // `?Sized` so the trait object itself can be checked, not just a concrete impl.
        fn assert_send_sync<T: Send + Sync + ?Sized>() {}
        assert_send_sync::<AddOne>();
        assert_send_sync::<dyn ScalarFunction>();

        let f: Arc<dyn ScalarFunction> = Arc::new(AddOne);
        let f2 = Arc::clone(&f);
        let handle = std::thread::spawn(move || f2.invoke(&[SqliteValue::Integer(0)]));
        // Check both results rather than discarding the main-thread one, so an
        // error on either side fails the test.
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(1)]).unwrap(),
            SqliteValue::Integer(2)
        );
        assert_eq!(handle.join().unwrap().unwrap(), SqliteValue::Integer(1));
    }
}