1use crate::error::{QueryError, QueryResult};
4use chrono::{DateTime, Utc};
5use std::collections::HashMap;
6
/// Boxed implementation of a scalar SQL function: it receives the evaluated
/// argument list and returns a single value or a `QueryError`.
///
/// The `Send + Sync` bounds let a registry holding these be shared across threads.
type SqlFunction = Box<dyn Fn(&[FunctionArg]) -> QueryResult<FunctionArg> + Send + Sync>;

/// Name-keyed table of scalar SQL functions.
///
/// Keys are stored lowercased by `register`, and `call` / `has_function`
/// lowercase the name they are given, so lookups are case-insensitive.
pub struct FunctionRegistry {
    // Lowercased function name -> implementation.
    functions: HashMap<String, SqlFunction>,
}
14
15impl Default for FunctionRegistry {
16 fn default() -> Self {
17 let mut registry = Self {
18 functions: HashMap::new(),
19 };
20
21 registry.register("date_format", date_format);
23 registry.register("concat", concat);
24 registry.register("length", length);
25 registry.register("upper", upper);
26 registry.register("lower", lower);
27 registry.register("substring", substring);
28 registry.register("coalesce", coalesce);
29 registry.register("now", now);
30
31 registry
32 }
33}
34
35impl FunctionRegistry {
36 pub fn register<F>(&mut self, name: &str, func: F)
38 where
39 F: Fn(&[FunctionArg]) -> QueryResult<FunctionArg> + Send + Sync + 'static,
40 {
41 self.functions.insert(name.to_lowercase(), Box::new(func));
42 }
43
44 pub fn call(&self, name: &str, args: &[FunctionArg]) -> QueryResult<FunctionArg> {
46 match self.functions.get(&name.to_lowercase()) {
47 Some(func) => func(args),
48 None => Err(QueryError::ParseError(format!(
49 "Unknown function: {}",
50 name
51 ))),
52 }
53 }
54
55 pub fn has_function(&self, name: &str) -> bool {
57 self.functions.contains_key(&name.to_lowercase())
58 }
59}
60
/// A dynamically-typed value passed to, and returned from, scalar SQL functions.
#[derive(Debug, Clone, PartialEq)]
pub enum FunctionArg {
    /// Text value.
    String(String),
    /// 64-bit signed integer.
    Number(i64),
    /// 64-bit floating-point number.
    Float(f64),
    /// Boolean value.
    Boolean(bool),
    /// SQL `NULL`.
    Null,
}

/// Borrowed text becomes an owned `FunctionArg::String`.
impl From<&str> for FunctionArg {
    fn from(value: &str) -> Self {
        Self::String(value.to_owned())
    }
}

/// Owned text is moved into `FunctionArg::String` without copying.
impl From<String> for FunctionArg {
    fn from(value: String) -> Self {
        Self::String(value)
    }
}

/// Integers map to `FunctionArg::Number`.
impl From<i64> for FunctionArg {
    fn from(value: i64) -> Self {
        Self::Number(value)
    }
}

/// Floats map to `FunctionArg::Float`.
impl From<f64> for FunctionArg {
    fn from(value: f64) -> Self {
        Self::Float(value)
    }
}

/// Booleans map to `FunctionArg::Boolean`.
impl From<bool> for FunctionArg {
    fn from(value: bool) -> Self {
        Self::Boolean(value)
    }
}
105
106pub fn date_format(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
114 if args.len() != 2 {
115 return Err(QueryError::ParseError(
116 "date_format requires 2 arguments: timestamp and format".to_string(),
117 ));
118 }
119
120 let timestamp_str = match &args[0] {
122 FunctionArg::String(s) => s.as_str(),
123 _ => {
124 return Err(QueryError::ParseError(
125 "date_format first argument must be a string".to_string(),
126 ))
127 }
128 };
129
130 let format = match &args[1] {
131 FunctionArg::String(s) => s.as_str(),
132 _ => {
133 return Err(QueryError::ParseError(
134 "date_format second argument must be a string".to_string(),
135 ))
136 }
137 };
138
139 let dt = DateTime::parse_from_rfc3339(timestamp_str)
141 .map_err(|_| QueryError::ParseError(format!("Invalid timestamp: {}", timestamp_str)))?;
142
143 let result = match format {
144 "iso" => dt.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
145 "date" => dt.format("%Y-%m-%d").to_string(),
146 "time" => dt.format("%H:%M:%S").to_string(),
147 "unix" => dt.timestamp().to_string(),
148 "year" => dt.format("%Y").to_string(),
149 "month" => dt.format("%m").to_string(),
150 "day" => dt.format("%d").to_string(),
151 "hour" => dt.format("%H").to_string(),
152 "minute" => dt.format("%M").to_string(),
153 "ymd" => dt.format("%Y%m%d").to_string(),
154 _ => {
155 return Err(QueryError::ParseError(format!(
156 "Unknown format: {}",
157 format
158 )))
159 }
160 };
161
162 Ok(FunctionArg::String(result))
163}
164
165pub fn concat(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
167 if args.is_empty() {
168 return Ok(FunctionArg::String(String::new()));
169 }
170
171 let mut result = String::new();
172 for arg in args {
173 match arg {
174 FunctionArg::String(s) => result.push_str(s),
175 FunctionArg::Number(n) => result.push_str(&n.to_string()),
176 FunctionArg::Float(f) => result.push_str(&f.to_string()),
177 FunctionArg::Boolean(b) => result.push_str(&b.to_string()),
178 FunctionArg::Null => result.push_str("NULL"),
179 }
180 }
181
182 Ok(FunctionArg::String(result))
183}
184
185pub fn length(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
187 if args.len() != 1 {
188 return Err(QueryError::ParseError(
189 "length requires 1 argument".to_string(),
190 ));
191 }
192
193 let len = match &args[0] {
194 FunctionArg::String(s) => s.len(),
195 FunctionArg::Number(n) => n.to_string().len(),
196 FunctionArg::Float(f) => f.to_string().len(),
197 FunctionArg::Boolean(b) => b.to_string().len(),
198 FunctionArg::Null => 4, };
200
201 Ok(FunctionArg::Number(len as i64))
202}
203
204pub fn upper(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
206 if args.len() != 1 {
207 return Err(QueryError::ParseError(
208 "upper requires 1 argument".to_string(),
209 ));
210 }
211
212 let result = match &args[0] {
213 FunctionArg::String(s) => s.to_uppercase(),
214 FunctionArg::Number(n) => n.to_string().to_uppercase(),
215 FunctionArg::Float(f) => f.to_string().to_uppercase(),
216 FunctionArg::Boolean(b) => b.to_string().to_uppercase(),
217 FunctionArg::Null => String::from("NULL"),
218 };
219
220 Ok(FunctionArg::String(result))
221}
222
223pub fn lower(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
225 if args.len() != 1 {
226 return Err(QueryError::ParseError(
227 "lower requires 1 argument".to_string(),
228 ));
229 }
230
231 let result = match &args[0] {
232 FunctionArg::String(s) => s.to_lowercase(),
233 FunctionArg::Number(n) => n.to_string().to_lowercase(),
234 FunctionArg::Float(f) => f.to_string().to_lowercase(),
235 FunctionArg::Boolean(b) => b.to_string().to_lowercase(),
236 FunctionArg::Null => String::from("null"),
237 };
238
239 Ok(FunctionArg::String(result))
240}
241
242pub fn substring(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
244 if args.len() < 2 || args.len() > 3 {
245 return Err(QueryError::ParseError(
246 "substring requires 2 or 3 arguments: string, start, [length]".to_string(),
247 ));
248 }
249
250 let s = match &args[0] {
251 FunctionArg::String(s) => s.as_str(),
252 _ => {
253 return Err(QueryError::ParseError(
254 "substring first argument must be a string".to_string(),
255 ))
256 }
257 };
258
259 let start = match &args[1] {
260 FunctionArg::Number(n) => *n as usize,
261 _ => {
262 return Err(QueryError::ParseError(
263 "substring second argument must be a number".to_string(),
264 ))
265 }
266 };
267
268 let result = if args.len() == 3 {
269 let length = match &args[2] {
270 FunctionArg::Number(n) => *n as usize,
271 _ => {
272 return Err(QueryError::ParseError(
273 "substring third argument must be a number".to_string(),
274 ))
275 }
276 };
277 let start_idx = start.saturating_sub(1);
279 s.chars().skip(start_idx).take(length).collect()
280 } else {
281 let start_idx = start.saturating_sub(1);
283 s.chars().skip(start_idx).collect()
284 };
285
286 Ok(FunctionArg::String(result))
287}
288
289pub fn coalesce(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
291 for arg in args {
292 if arg != &FunctionArg::Null {
293 return Ok(arg.clone());
294 }
295 }
296 Ok(FunctionArg::Null)
297}
298
299pub fn now(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
301 if !args.is_empty() {
302 return Err(QueryError::ParseError(
303 "now requires no arguments".to_string(),
304 ));
305 }
306
307 let now: DateTime<Utc> = Utc::now();
308 Ok(FunctionArg::String(
309 now.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
310 ))
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_function_registry_default() {
        let registry = FunctionRegistry::default();
        for name in [
            "date_format",
            "concat",
            "length",
            "upper",
            "lower",
            "substring",
            "coalesce",
            "now",
        ]
        .iter()
        {
            assert!(registry.has_function(name), "missing builtin: {}", name);
        }
    }

    #[test]
    fn test_date_format() {
        let args = vec![
            FunctionArg::String("2024-01-15T10:30:00Z".to_string()),
            FunctionArg::String("iso".to_string()),
        ];
        let result = date_format(&args).unwrap();
        assert_eq!(
            result,
            FunctionArg::String("2024-01-15T10:30:00Z".to_string())
        );

        let args = vec![
            FunctionArg::String("2024-01-15T10:30:00Z".to_string()),
            FunctionArg::String("date".to_string()),
        ];
        let result = date_format(&args).unwrap();
        assert_eq!(result, FunctionArg::String("2024-01-15".to_string()));
    }

    #[test]
    fn test_date_format_components() {
        // 2024-01-01T00:00:00Z is 1704067200; plus 14 days and 10h30m.
        let cases = [
            ("time", "10:30:00"),
            ("unix", "1705314600"),
            ("year", "2024"),
            ("month", "01"),
            ("day", "15"),
            ("hour", "10"),
            ("minute", "30"),
            ("ymd", "20240115"),
        ];
        for (format, expected) in cases.iter() {
            let args = vec![
                FunctionArg::from("2024-01-15T10:30:00Z"),
                FunctionArg::from(*format),
            ];
            assert_eq!(
                date_format(&args).unwrap(),
                FunctionArg::String(expected.to_string()),
                "format {}",
                format
            );
        }
    }

    #[test]
    fn test_date_format_errors() {
        let ts = FunctionArg::from("2024-01-15T10:30:00Z");
        assert!(date_format(&[ts.clone()]).is_err());
        assert!(date_format(&[FunctionArg::Number(0), FunctionArg::from("iso")]).is_err());
        assert!(date_format(&[ts.clone(), FunctionArg::Number(1)]).is_err());
        assert!(date_format(&[FunctionArg::from("yesterday"), FunctionArg::from("iso")]).is_err());
        assert!(date_format(&[ts, FunctionArg::from("fortnight")]).is_err());
    }

    #[test]
    fn test_concat() {
        let args = vec![
            FunctionArg::String("Hello".to_string()),
            FunctionArg::String(" ".to_string()),
            FunctionArg::String("World".to_string()),
        ];
        let result = concat(&args).unwrap();
        assert_eq!(result, FunctionArg::String("Hello World".to_string()));

        let args = vec![
            FunctionArg::String("Count: ".to_string()),
            FunctionArg::Number(42),
        ];
        let result = concat(&args).unwrap();
        assert_eq!(result, FunctionArg::String("Count: 42".to_string()));
    }

    #[test]
    fn test_empty_variadic_calls() {
        assert_eq!(concat(&[]).unwrap(), FunctionArg::String(String::new()));
        assert_eq!(coalesce(&[]).unwrap(), FunctionArg::Null);
    }

    #[test]
    fn test_length() {
        let args = vec![FunctionArg::String("hello".to_string())];
        let result = length(&args).unwrap();
        assert_eq!(result, FunctionArg::Number(5));

        let args = vec![FunctionArg::String("".to_string())];
        let result = length(&args).unwrap();
        assert_eq!(result, FunctionArg::Number(0));
    }

    #[test]
    fn test_upper_lower() {
        let args = vec![FunctionArg::String("Hello".to_string())];
        let result = upper(&args).unwrap();
        assert_eq!(result, FunctionArg::String("HELLO".to_string()));

        let args = vec![FunctionArg::String("HELLO".to_string())];
        let result = lower(&args).unwrap();
        assert_eq!(result, FunctionArg::String("hello".to_string()));
    }

    #[test]
    fn test_arity_errors() {
        assert!(length(&[]).is_err());
        assert!(upper(&[]).is_err());
        assert!(lower(&[FunctionArg::Null, FunctionArg::Null]).is_err());
        assert!(substring(&[FunctionArg::from("hello")]).is_err());
        assert!(substring(&[
            FunctionArg::from("hello"),
            FunctionArg::Number(1),
            FunctionArg::Number(1),
            FunctionArg::Number(1),
        ])
        .is_err());
        assert!(now(&[FunctionArg::Null]).is_err());
    }

    #[test]
    fn test_substring() {
        let args = vec![
            FunctionArg::String("hello".to_string()),
            FunctionArg::Number(2),
            FunctionArg::Number(3),
        ];
        let result = substring(&args).unwrap();
        assert_eq!(result, FunctionArg::String("ell".to_string()));

        let args = vec![
            FunctionArg::String("hello".to_string()),
            FunctionArg::Number(2),
        ];
        let result = substring(&args).unwrap();
        assert_eq!(result, FunctionArg::String("ello".to_string()));
    }

    #[test]
    fn test_substring_type_errors() {
        assert!(substring(&[FunctionArg::Number(1), FunctionArg::Number(1)]).is_err());
        assert!(substring(&[FunctionArg::from("hello"), FunctionArg::from("1")]).is_err());
        assert!(substring(&[
            FunctionArg::from("hello"),
            FunctionArg::Number(1),
            FunctionArg::Null,
        ])
        .is_err());
    }

    #[test]
    fn test_coalesce() {
        let args = vec![
            FunctionArg::Null,
            FunctionArg::String("default".to_string()),
            FunctionArg::String("other".to_string()),
        ];
        let result = coalesce(&args).unwrap();
        assert_eq!(result, FunctionArg::String("default".to_string()));

        let args = vec![FunctionArg::Null, FunctionArg::Null];
        let result = coalesce(&args).unwrap();
        assert_eq!(result, FunctionArg::Null);
    }

    #[test]
    fn test_now() {
        let result = now(&[]).unwrap();
        match result {
            FunctionArg::String(s) => {
                assert!(DateTime::parse_from_rfc3339(s.as_str()).is_ok());
            }
            _ => panic!("now() should return a string"),
        }
    }

    #[test]
    fn test_function_registry_call() {
        let registry = FunctionRegistry::default();

        let args = vec![FunctionArg::String("hello".to_string())];
        let result = registry.call("upper", &args).unwrap();
        assert_eq!(result, FunctionArg::String("HELLO".to_string()));

        let args = vec![FunctionArg::String("test".to_string())];
        let result = registry.call("unknown", &args);
        assert!(result.is_err());
    }

    #[test]
    fn test_registry_lookup_is_case_insensitive() {
        let registry = FunctionRegistry::default();
        assert!(registry.has_function("UPPER"));
        assert!(!registry.has_function("no_such_function"));
        let result = registry.call("Upper", &[FunctionArg::from("abc")]).unwrap();
        assert_eq!(result, FunctionArg::String("ABC".to_string()));
    }
}
456}