Skip to main content

cai_query/
functions.rs

1//! Built-in SQL functions for CAI queries
2
3use crate::error::{QueryError, QueryResult};
4use chrono::{DateTime, Utc};
5use std::collections::HashMap;
6
7/// Type alias for SQL function implementation
8type SqlFunction = Box<dyn Fn(&[FunctionArg]) -> QueryResult<FunctionArg> + Send + Sync>;
9
10/// Function registry for SQL functions
11pub struct FunctionRegistry {
12    functions: HashMap<String, SqlFunction>,
13}
14
15impl Default for FunctionRegistry {
16    fn default() -> Self {
17        let mut registry = Self {
18            functions: HashMap::new(),
19        };
20
21        // Register built-in functions
22        registry.register("date_format", date_format);
23        registry.register("concat", concat);
24        registry.register("length", length);
25        registry.register("upper", upper);
26        registry.register("lower", lower);
27        registry.register("substring", substring);
28        registry.register("coalesce", coalesce);
29        registry.register("now", now);
30
31        registry
32    }
33}
34
35impl FunctionRegistry {
36    /// Register a new function
37    pub fn register<F>(&mut self, name: &str, func: F)
38    where
39        F: Fn(&[FunctionArg]) -> QueryResult<FunctionArg> + Send + Sync + 'static,
40    {
41        self.functions.insert(name.to_lowercase(), Box::new(func));
42    }
43
44    /// Call a function by name
45    pub fn call(&self, name: &str, args: &[FunctionArg]) -> QueryResult<FunctionArg> {
46        match self.functions.get(&name.to_lowercase()) {
47            Some(func) => func(args),
48            None => Err(QueryError::ParseError(format!(
49                "Unknown function: {}",
50                name
51            ))),
52        }
53    }
54
55    /// Check if a function exists
56    pub fn has_function(&self, name: &str) -> bool {
57        self.functions.contains_key(&name.to_lowercase())
58    }
59}
60
/// A dynamically-typed value passed to, or returned from, a SQL function.
///
/// Equality is derived, so it is only partial: because `Float` wraps an `f64`,
/// `Float(f64::NAN) != Float(f64::NAN)`.
#[derive(Debug, Clone, PartialEq)]
pub enum FunctionArg {
    /// Text value.
    String(String),
    /// Signed 64-bit integer.
    Number(i64),
    /// 64-bit floating-point number.
    Float(f64),
    /// `true` or `false`.
    Boolean(bool),
    /// Absence of a value (the SQL `NULL` that `coalesce` skips over).
    Null,
}
75
76impl From<&str> for FunctionArg {
77    fn from(s: &str) -> Self {
78        FunctionArg::String(s.to_string())
79    }
80}
81
82impl From<String> for FunctionArg {
83    fn from(s: String) -> Self {
84        FunctionArg::String(s)
85    }
86}
87
88impl From<i64> for FunctionArg {
89    fn from(n: i64) -> Self {
90        FunctionArg::Number(n)
91    }
92}
93
94impl From<f64> for FunctionArg {
95    fn from(n: f64) -> Self {
96        FunctionArg::Float(n)
97    }
98}
99
100impl From<bool> for FunctionArg {
101    fn from(b: bool) -> Self {
102        FunctionArg::Boolean(b)
103    }
104}
105
106// ============================================================================
107// Built-in Functions
108// ============================================================================
109
110/// Format a DateTime to a string representation
111///
112/// Supports formats: "iso", "date", "time", "unix", "year", "month", "day", "hour", "minute", "ymd"
113pub fn date_format(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
114    if args.len() != 2 {
115        return Err(QueryError::ParseError(
116            "date_format requires 2 arguments: timestamp and format".to_string(),
117        ));
118    }
119
120    // For now, we'll work with string timestamps
121    let timestamp_str = match &args[0] {
122        FunctionArg::String(s) => s.as_str(),
123        _ => {
124            return Err(QueryError::ParseError(
125                "date_format first argument must be a string".to_string(),
126            ))
127        }
128    };
129
130    let format = match &args[1] {
131        FunctionArg::String(s) => s.as_str(),
132        _ => {
133            return Err(QueryError::ParseError(
134                "date_format second argument must be a string".to_string(),
135            ))
136        }
137    };
138
139    // Parse ISO 8601 timestamp
140    let dt = DateTime::parse_from_rfc3339(timestamp_str)
141        .map_err(|_| QueryError::ParseError(format!("Invalid timestamp: {}", timestamp_str)))?;
142
143    let result = match format {
144        "iso" => dt.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
145        "date" => dt.format("%Y-%m-%d").to_string(),
146        "time" => dt.format("%H:%M:%S").to_string(),
147        "unix" => dt.timestamp().to_string(),
148        "year" => dt.format("%Y").to_string(),
149        "month" => dt.format("%m").to_string(),
150        "day" => dt.format("%d").to_string(),
151        "hour" => dt.format("%H").to_string(),
152        "minute" => dt.format("%M").to_string(),
153        "ymd" => dt.format("%Y%m%d").to_string(),
154        _ => {
155            return Err(QueryError::ParseError(format!(
156                "Unknown format: {}",
157                format
158            )))
159        }
160    };
161
162    Ok(FunctionArg::String(result))
163}
164
165/// Concatenate strings
166pub fn concat(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
167    if args.is_empty() {
168        return Ok(FunctionArg::String(String::new()));
169    }
170
171    let mut result = String::new();
172    for arg in args {
173        match arg {
174            FunctionArg::String(s) => result.push_str(s),
175            FunctionArg::Number(n) => result.push_str(&n.to_string()),
176            FunctionArg::Float(f) => result.push_str(&f.to_string()),
177            FunctionArg::Boolean(b) => result.push_str(&b.to_string()),
178            FunctionArg::Null => result.push_str("NULL"),
179        }
180    }
181
182    Ok(FunctionArg::String(result))
183}
184
185/// Get string length
186pub fn length(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
187    if args.len() != 1 {
188        return Err(QueryError::ParseError(
189            "length requires 1 argument".to_string(),
190        ));
191    }
192
193    let len = match &args[0] {
194        FunctionArg::String(s) => s.len(),
195        FunctionArg::Number(n) => n.to_string().len(),
196        FunctionArg::Float(f) => f.to_string().len(),
197        FunctionArg::Boolean(b) => b.to_string().len(),
198        FunctionArg::Null => 4, // "NULL"
199    };
200
201    Ok(FunctionArg::Number(len as i64))
202}
203
204/// Convert string to uppercase
205pub fn upper(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
206    if args.len() != 1 {
207        return Err(QueryError::ParseError(
208            "upper requires 1 argument".to_string(),
209        ));
210    }
211
212    let result = match &args[0] {
213        FunctionArg::String(s) => s.to_uppercase(),
214        FunctionArg::Number(n) => n.to_string().to_uppercase(),
215        FunctionArg::Float(f) => f.to_string().to_uppercase(),
216        FunctionArg::Boolean(b) => b.to_string().to_uppercase(),
217        FunctionArg::Null => String::from("NULL"),
218    };
219
220    Ok(FunctionArg::String(result))
221}
222
223/// Convert string to lowercase
224pub fn lower(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
225    if args.len() != 1 {
226        return Err(QueryError::ParseError(
227            "lower requires 1 argument".to_string(),
228        ));
229    }
230
231    let result = match &args[0] {
232        FunctionArg::String(s) => s.to_lowercase(),
233        FunctionArg::Number(n) => n.to_string().to_lowercase(),
234        FunctionArg::Float(f) => f.to_string().to_lowercase(),
235        FunctionArg::Boolean(b) => b.to_string().to_lowercase(),
236        FunctionArg::Null => String::from("null"),
237    };
238
239    Ok(FunctionArg::String(result))
240}
241
242/// Extract substring (1-indexed start position)
243pub fn substring(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
244    if args.len() < 2 || args.len() > 3 {
245        return Err(QueryError::ParseError(
246            "substring requires 2 or 3 arguments: string, start, [length]".to_string(),
247        ));
248    }
249
250    let s = match &args[0] {
251        FunctionArg::String(s) => s.as_str(),
252        _ => {
253            return Err(QueryError::ParseError(
254                "substring first argument must be a string".to_string(),
255            ))
256        }
257    };
258
259    let start = match &args[1] {
260        FunctionArg::Number(n) => *n as usize,
261        _ => {
262            return Err(QueryError::ParseError(
263                "substring second argument must be a number".to_string(),
264            ))
265        }
266    };
267
268    let result = if args.len() == 3 {
269        let length = match &args[2] {
270            FunctionArg::Number(n) => *n as usize,
271            _ => {
272                return Err(QueryError::ParseError(
273                    "substring third argument must be a number".to_string(),
274                ))
275            }
276        };
277        // Convert 1-indexed to 0-indexed
278        let start_idx = start.saturating_sub(1);
279        s.chars().skip(start_idx).take(length).collect()
280    } else {
281        // Convert 1-indexed to 0-indexed
282        let start_idx = start.saturating_sub(1);
283        s.chars().skip(start_idx).collect()
284    };
285
286    Ok(FunctionArg::String(result))
287}
288
289/// Return first non-null argument
290pub fn coalesce(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
291    for arg in args {
292        if arg != &FunctionArg::Null {
293            return Ok(arg.clone());
294        }
295    }
296    Ok(FunctionArg::Null)
297}
298
299/// Get current timestamp
300pub fn now(args: &[FunctionArg]) -> QueryResult<FunctionArg> {
301    if !args.is_empty() {
302        return Err(QueryError::ParseError(
303            "now requires no arguments".to_string(),
304        ));
305    }
306
307    let now: DateTime<Utc> = Utc::now();
308    Ok(FunctionArg::String(
309        now.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
310    ))
311}
312
313// ============================================================================
314// Tests
315// ============================================================================
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a string argument from a literal.
    fn text(value: &str) -> FunctionArg {
        FunctionArg::String(value.to_string())
    }

    #[test]
    fn test_function_registry_default() {
        let registry = FunctionRegistry::default();
        let builtins = [
            "date_format",
            "concat",
            "length",
            "upper",
            "lower",
            "substring",
            "coalesce",
            "now",
        ];
        for name in builtins.iter() {
            assert!(registry.has_function(name));
        }
    }

    #[test]
    fn test_date_format() {
        let stamp = "2024-01-15T10:30:00Z";
        let cases = [("iso", "2024-01-15T10:30:00Z"), ("date", "2024-01-15")];
        for &(fmt_name, expected) in cases.iter() {
            let formatted = date_format(&[text(stamp), text(fmt_name)]).unwrap();
            assert_eq!(formatted, text(expected));
        }
    }

    #[test]
    fn test_concat() {
        let greeting = concat(&[text("Hello"), text(" "), text("World")]).unwrap();
        assert_eq!(greeting, text("Hello World"));

        // Numbers are rendered in their decimal text form.
        let counted = concat(&[text("Count: "), FunctionArg::Number(42)]).unwrap();
        assert_eq!(counted, text("Count: 42"));
    }

    #[test]
    fn test_length() {
        assert_eq!(length(&[text("hello")]).unwrap(), FunctionArg::Number(5));
        assert_eq!(length(&[text("")]).unwrap(), FunctionArg::Number(0));
    }

    #[test]
    fn test_upper_lower() {
        assert_eq!(upper(&[text("Hello")]).unwrap(), text("HELLO"));
        assert_eq!(lower(&[text("HELLO")]).unwrap(), text("hello"));
    }

    #[test]
    fn test_substring() {
        let with_length =
            substring(&[text("hello"), FunctionArg::Number(2), FunctionArg::Number(3)]).unwrap();
        assert_eq!(with_length, text("ell"));

        // Without a length, everything from the start position onward is returned.
        let to_end = substring(&[text("hello"), FunctionArg::Number(2)]).unwrap();
        assert_eq!(to_end, text("ello"));
    }

    #[test]
    fn test_coalesce() {
        let picked = coalesce(&[FunctionArg::Null, text("default"), text("other")]).unwrap();
        assert_eq!(picked, text("default"));

        // When every argument is NULL the result is NULL.
        let all_null = coalesce(&[FunctionArg::Null, FunctionArg::Null]).unwrap();
        assert_eq!(all_null, FunctionArg::Null);
    }

    #[test]
    fn test_now() {
        if let FunctionArg::String(stamp) = now(&[]).unwrap() {
            // Must parse back as an RFC 3339 / ISO 8601 timestamp.
            assert!(DateTime::parse_from_rfc3339(&stamp).is_ok());
        } else {
            panic!("now() should return a string");
        }
    }

    #[test]
    fn test_function_registry_call() {
        let registry = FunctionRegistry::default();

        // Known functions dispatch to their implementation.
        let shouted = registry.call("upper", &[text("hello")]).unwrap();
        assert_eq!(shouted, text("HELLO"));

        // Unknown names are reported as errors rather than panicking.
        assert!(registry.call("unknown", &[text("test")]).is_err());
    }
}