1#![allow(clippy::unnecessary_literal_bound)]
22
23use fsqlite_error::Result;
24use fsqlite_types::SqliteValue;
25
26pub trait ScalarFunction: Send + Sync {
39 fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue>;
41
42 fn is_deterministic(&self) -> bool {
47 true
48 }
49
50 fn num_args(&self) -> i32;
54
55 fn name(&self) -> &str;
57}
58
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use fsqlite_error::FrankenError;

    use super::*;

    /// Unary function returning its argument plus one.
    ///
    /// Integers and floats are incremented, NULL passes through, text is
    /// parsed as an integer (unparseable text counts as 0) and any blob
    /// yields 1.
    struct AddOne;

    impl ScalarFunction for AddOne {
        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
            // Relies on the caller honouring `num_args() == 1`; an empty
            // slice would panic on this index.
            match &args[0] {
                SqliteValue::Integer(i) => Ok(SqliteValue::Integer(i + 1)),
                SqliteValue::Float(f) => Ok(SqliteValue::Float(f + 1.0)),
                SqliteValue::Null => Ok(SqliteValue::Null),
                SqliteValue::Text(s) => {
                    let n: i64 = s.parse().unwrap_or(0);
                    Ok(SqliteValue::Integer(n + 1))
                }
                SqliteValue::Blob(_) => Ok(SqliteValue::Integer(1)),
            }
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "add_one"
        }
    }

    /// Zero-argument function that overrides `is_deterministic` to `false`
    /// (it always returns 42; only the flag matters to the tests).
    struct NonDeterministic;

    impl ScalarFunction for NonDeterministic {
        fn invoke(&self, _args: &[SqliteValue]) -> Result<SqliteValue> {
            Ok(SqliteValue::Integer(42))
        }

        fn is_deterministic(&self) -> bool {
            false
        }

        fn num_args(&self) -> i32 {
            0
        }

        fn name(&self) -> &str {
            "random_ish"
        }
    }

    /// Variadic function (`num_args() == -1`) that concatenates the text
    /// form of every argument.
    struct Concat;

    impl ScalarFunction for Concat {
        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
            let mut result = String::new();
            for arg in args {
                result.push_str(&arg.to_text());
            }
            Ok(SqliteValue::Text(result.into()))
        }

        fn num_args(&self) -> i32 {
            -1
        }

        fn name(&self) -> &str {
            "concat"
        }
    }

    /// Unary absolute value that reports `i64::MIN` as a function error
    /// instead of overflowing; non-integer arguments are returned unchanged.
    struct SafeAbs;

    impl ScalarFunction for SafeAbs {
        fn invoke(&self, args: &[SqliteValue]) -> Result<SqliteValue> {
            match &args[0] {
                // `checked_abs` yields `None` exactly for `i64::MIN`, whose
                // absolute value is not representable as an `i64`.
                SqliteValue::Integer(i) => i
                    .checked_abs()
                    .map(SqliteValue::Integer)
                    .ok_or_else(|| FrankenError::function_error("abs(i64::MIN) would overflow")),
                other => Ok(other.clone()),
            }
        }

        fn num_args(&self) -> i32 {
            1
        }

        fn name(&self) -> &str {
            "safe_abs"
        }
    }

    /// Zero-argument function that always fails with `FrankenError::TooBig`.
    struct BigResult;

    impl ScalarFunction for BigResult {
        fn invoke(&self, _args: &[SqliteValue]) -> Result<SqliteValue> {
            Err(FrankenError::TooBig)
        }

        fn num_args(&self) -> i32 {
            0
        }

        fn name(&self) -> &str {
            "big_result"
        }
    }

    #[test]
    fn test_scalar_function_invoke_basic() {
        let f = AddOne;
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(41)]).unwrap(),
            SqliteValue::Integer(42)
        );
        assert_eq!(
            f.invoke(&[SqliteValue::Float(1.5)]).unwrap(),
            SqliteValue::Float(2.5)
        );
        assert!(f.invoke(&[SqliteValue::Null]).unwrap().is_null());
        assert_eq!(
            f.invoke(&[SqliteValue::Text("99".into())]).unwrap(),
            SqliteValue::Integer(100)
        );
        // Unparseable text falls back to 0 before the increment.
        assert_eq!(
            f.invoke(&[SqliteValue::Text("abc".into())]).unwrap(),
            SqliteValue::Integer(1)
        );
    }

    #[test]
    fn test_scalar_function_metadata() {
        assert_eq!(AddOne.name(), "add_one");
        assert_eq!(AddOne.num_args(), 1);
        assert_eq!(NonDeterministic.name(), "random_ish");
        assert_eq!(NonDeterministic.num_args(), 0);
        assert_eq!(Concat.name(), "concat");
        assert_eq!(SafeAbs.name(), "safe_abs");
        assert_eq!(SafeAbs.num_args(), 1);
        assert_eq!(BigResult.name(), "big_result");
        assert_eq!(BigResult.num_args(), 0);
    }

    #[test]
    fn test_scalar_function_deterministic_flag() {
        let det = AddOne;
        assert!(det.is_deterministic());

        let non_det = NonDeterministic;
        assert!(!non_det.is_deterministic());
    }

    #[test]
    fn test_scalar_function_variadic() {
        let f = Concat;
        assert_eq!(f.num_args(), -1);

        assert_eq!(f.invoke(&[]).unwrap(), SqliteValue::Text("".into()));

        assert_eq!(
            f.invoke(&[SqliteValue::Text("hello".into())]).unwrap(),
            SqliteValue::Text("hello".into())
        );

        assert_eq!(
            f.invoke(&[
                SqliteValue::Text("a".into()),
                SqliteValue::Text("b".into()),
                SqliteValue::Text("c".into()),
            ])
            .unwrap(),
            SqliteValue::Text("abc".into())
        );
    }

    #[test]
    fn test_scalar_function_error_domain() {
        let f = SafeAbs;
        let err = f.invoke(&[SqliteValue::Integer(i64::MIN)]).unwrap_err();
        assert!(
            matches!(err, FrankenError::FunctionError(ref msg) if msg.contains("overflow")),
            "expected FunctionError, got {err:?}"
        );

        // In-range integers (including the boundary next to i64::MIN) and
        // non-integer values still succeed.
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(-5)]).unwrap(),
            SqliteValue::Integer(5)
        );
        assert_eq!(
            f.invoke(&[SqliteValue::Integer(i64::MIN + 1)]).unwrap(),
            SqliteValue::Integer(i64::MAX)
        );
        assert!(f.invoke(&[SqliteValue::Null]).unwrap().is_null());
    }

    #[test]
    fn test_scalar_function_too_big_error() {
        let f = BigResult;
        let err = f.invoke(&[]).unwrap_err();
        assert!(matches!(err, FrankenError::TooBig));
    }

    #[test]
    fn test_scalar_send_sync() {
        fn assert_send_sync<T: Send + Sync + ?Sized>() {}
        assert_send_sync::<AddOne>();
        assert_send_sync::<dyn ScalarFunction>();

        // One instance shared by two threads: both results are checked so a
        // wrong value on either side cannot pass silently.
        let f: Arc<dyn ScalarFunction> = Arc::new(AddOne);
        let f2 = Arc::clone(&f);
        let handle = std::thread::spawn(move || f2.invoke(&[SqliteValue::Integer(0)]));
        let local = f.invoke(&[SqliteValue::Integer(1)]).unwrap();
        let remote = handle.join().expect("worker thread panicked").unwrap();
        assert_eq!(local, SqliteValue::Integer(2));
        assert_eq!(remote, SqliteValue::Integer(1));
    }
}