qrlew/expr/
implementation.rs

1use super::{aggregate::Aggregate, function::Function};
2use crate::data_type::{
3    function::{self, Optional},
4    DataType,
5};
6use paste::paste;
7use rand::rngs::OsRng;
8use std::sync::{Arc, Mutex};
9
// Generates, from five lists of `Function` variant names plus a fallback:
//   * a `FunctionImplementations` struct with one `Arc<dyn function::Function>` field per
//     listed variant (field name = snake_case of the variant name),
//   * a thread-local `FUNCTION_IMPLEMENTATIONS` instance of that struct,
//   * the public `function` accessor that maps a `Function` to its implementation.
//
// The first list is built with `function::<name>()` directly; the four remaining lists all
// expand the same way, wrapping `function::<name>()` in `Optional::new`.
// The last two arguments are an identifier and a block: the identifier is bound to any
// variant not named in the lists and the block computes the implementation for it.
macro_rules! function_implementations {
    (
        [$($arity0:ident),*],
        [$($arity1:ident),*],
        [$($arity2:ident),*],
        [$($arity3:ident),*],
        [$($arity4:ident),*],
        $fallback_binding:ident,
        $fallback:block
    ) => {
        paste! {
            /// Table holding one shared implementation per listed function
            struct FunctionImplementations {
                $(pub [< $arity0:snake >]: Arc<dyn function::Function>,)*
                $(pub [< $arity1:snake >]: Arc<dyn function::Function>,)*
                $(pub [< $arity2:snake >]: Arc<dyn function::Function>,)*
                $(pub [< $arity3:snake >]: Arc<dyn function::Function>,)*
                $(pub [< $arity4:snake >]: Arc<dyn function::Function>,)*
            }

            // One instance of the table per thread
            thread_local! {
                static FUNCTION_IMPLEMENTATIONS: FunctionImplementations = FunctionImplementations {
                    // First list: stored without the `Optional` wrapper
                    $([< $arity0:snake >]: Arc::new(function::[< $arity0:snake >]()),)*
                    // Remaining lists: wrapped in `Optional`
                    $([< $arity1:snake >]: Arc::new(Optional::new(function::[< $arity1:snake >]())),)*
                    $([< $arity2:snake >]: Arc::new(Optional::new(function::[< $arity2:snake >]())),)*
                    $([< $arity3:snake >]: Arc::new(Optional::new(function::[< $arity3:snake >]())),)*
                    $([< $arity4:snake >]: Arc::new(Optional::new(function::[< $arity4:snake >]())),)*
                };
            }

            /// Returns the implementation of `function`.
            ///
            /// Listed variants hand out a clone of the `Arc` stored in the thread-local
            /// table; every other variant is resolved by the fallback block.
            pub fn function(function: Function) -> Arc<dyn function::Function> {
                match function {
                    $(Function::$arity0 => FUNCTION_IMPLEMENTATIONS.with(|table| Arc::clone(&table.[< $arity0:snake >])),)*
                    $(Function::$arity1 => FUNCTION_IMPLEMENTATIONS.with(|table| Arc::clone(&table.[< $arity1:snake >])),)*
                    $(Function::$arity2 => FUNCTION_IMPLEMENTATIONS.with(|table| Arc::clone(&table.[< $arity2:snake >])),)*
                    $(Function::$arity3 => FUNCTION_IMPLEMENTATIONS.with(|table| Arc::clone(&table.[< $arity3:snake >])),)*
                    $(Function::$arity4 => FUNCTION_IMPLEMENTATIONS.with(|table| Arc::clone(&table.[< $arity4:snake >])),)*
                    $fallback_binding => $fallback
                }
            }
        }
    };
}
47
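// Functions registered below, grouped by the macro list they are passed in:
// Nullary: Pi, Newid, CurrentDate, CurrentTime, CurrentTimestamp
// Unary: Opposite, Not, Exp, Ln, Log, Abs, Sin, Cos, Sqrt, Md5, Ceil, Floor, Sign, Unhex, Dayname, Quarter, Date, UnixTimestamp, IsNull
// Binary: Plus, Minus, Multiply, Divide, Modulo, StringConcat, Gt, Lt, GtEq, LtEq, Eq, NotEq, And, Or, Xor, BitwiseOr, BitwiseAnd, BitwiseXor, Pow, CharLength, Lower, Upper, InList, Least, Greatest, Rtrim, Ltrim, Substr, Round, Trunc, RegexpContains, Encode, Decode, ExtractEpoch, ExtractYear, ExtractMonth, ExtractDay, ExtractHour, ExtractMinute, ExtractSecond, ExtractMicrosecond, ExtractMillisecond, ExtractDow, ExtractWeek, DateFormat, FromUnixtime, Like, Ilike, IsBool
// Ternary: Case, Position, SubstrWithSize, RegexpReplace, DatetimeDiff
// Quaternary: RegexpExtract
// Nary: Concat
// Built on demand in the fallback block: CastAsText, CastAsInteger, CastAsFloat, CastAsBoolean, CastAsDateTime, CastAsDate, CastAsTime, Concat, Random, Coalesce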
55function_implementations!(
56    [Pi, Newid, CurrentDate, CurrentTime, CurrentTimestamp],
57    [
58        Opposite,
59        Not,
60        Exp,
61        Ln,
62        Log,
63        Abs,
64        Sin,
65        Cos,
66        Sqrt,
67        Md5,
68        Ceil,
69        Floor,
70        Sign,
71        Unhex,
72        Dayname,
73        Quarter,
74        Date,
75        UnixTimestamp,
76        IsNull
77    ],
78    [
79        Plus,
80        Minus,
81        Multiply,
82        Divide,
83        Modulo,
84        StringConcat,
85        Gt,
86        Lt,
87        GtEq,
88        LtEq,
89        Eq,
90        NotEq,
91        And,
92        Or,
93        Xor,
94        BitwiseOr,
95        BitwiseAnd,
96        BitwiseXor,
97        Pow,
98        CharLength,
99        Lower,
100        Upper,
101        InList,
102        Least,
103        Greatest,
104        Rtrim,
105        Ltrim,
106        Substr,
107        Round,
108        Trunc,
109        RegexpContains,
110        Encode,
111        Decode,
112        ExtractEpoch,
113        ExtractYear,
114        ExtractMonth,
115        ExtractDay,
116        ExtractHour,
117        ExtractMinute,
118        ExtractSecond,
119        ExtractMicrosecond,
120        ExtractMillisecond,
121        ExtractDow,
122        ExtractWeek,
123        DateFormat,
124        FromUnixtime,
125        Like,
126        Ilike,
127        IsBool
128    ],
129    [Case, Position, SubstrWithSize, RegexpReplace, DatetimeDiff],
130    [RegexpExtract],
131    x,
132    {
133        match x {
134            Function::CastAsText => Arc::new(function::cast(DataType::text())),
135            Function::CastAsInteger => Arc::new(Optional::new(function::cast(DataType::integer()))),
136            Function::CastAsFloat => Arc::new(Optional::new(function::cast(DataType::float()))),
137            Function::CastAsBoolean => Arc::new(Optional::new(function::cast(DataType::boolean()))),
138            Function::CastAsDateTime => {
139                Arc::new(Optional::new(function::cast(DataType::date_time())))
140            }
141            Function::CastAsDate => Arc::new(Optional::new(function::cast(DataType::date()))),
142            Function::CastAsTime => Arc::new(Optional::new(function::cast(DataType::time()))),
143            Function::Concat(n) => Arc::new(function::concat(n)),
144            Function::Random(_n) => Arc::new(function::random(Mutex::new(OsRng))), //TODO change this initialization
145            Function::Coalesce => Arc::new(function::coalesce()),
146            _ => unreachable!(),
147        }
148    }
149);
150
// Generates, from one list of `Aggregate` variant names plus a fallback:
//   * an `AggregateImplementations` struct with one `Arc<dyn function::Function>` field per
//     listed variant (field name = snake_case of the variant name),
//   * a thread-local `AGGREGATE_IMPLEMENTATIONS` instance of that struct, each field built
//     from `function::<name>()` wrapped in `Optional::new`,
//   * the public `aggregate` accessor that maps an `Aggregate` to its implementation.
//
// The last two arguments are an identifier and a block: the identifier is bound to any
// variant not named in the list and the block computes the implementation for it.
macro_rules! aggregate_implementations {
    (
        [$($agg:ident),*],
        $fallback_binding:ident,
        $fallback:block
    ) => {
        paste! {
            /// Table holding one shared implementation per listed aggregate
            struct AggregateImplementations {
                $(pub [< $agg:snake >]: Arc<dyn function::Function>,)*
            }

            // One instance of the table per thread
            thread_local! {
                static AGGREGATE_IMPLEMENTATIONS: AggregateImplementations = AggregateImplementations {
                    $([< $agg:snake >]: Arc::new(Optional::new(function::[< $agg:snake >]())),)*
                };
            }

            /// Returns the implementation of `aggregate`.
            ///
            /// Listed variants hand out a clone of the `Arc` stored in the thread-local
            /// table; every other variant is resolved by the fallback block.
            pub fn aggregate(aggregate: Aggregate) -> Arc<dyn function::Function> {
                match aggregate {
                    $(Aggregate::$agg => AGGREGATE_IMPLEMENTATIONS.with(|table| Arc::clone(&table.[< $agg:snake >])),)*
                    $fallback_binding => $fallback
                }
            }
        }
    };
}
176
177aggregate_implementations!(
178    [
179        Min,
180        Max,
181        Median,
182        NUnique,
183        First,
184        Last,
185        Mean,
186        List,
187        Count,
188        Sum,
189        AggGroups,
190        Std,
191        Var,
192        MeanDistinct,
193        CountDistinct,
194        SumDistinct,
195        StdDistinct,
196        VarDistinct
197    ],
198    x,
199    {
200        match x {
201            Aggregate::Quantile(p) => Arc::new(function::quantile(p)),
202            Aggregate::Quantiles(p) => Arc::new(function::quantiles(p.iter().cloned().collect())),
203            _ => unreachable!(),
204        }
205    }
206);
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: looks up a few implementations and prints them.
    /// Nothing is asserted beyond the lookups and `super_image` not panicking.
    #[test]
    fn test_implementations() {
        // Functions served from the thread-local table.
        let exp = function(Function::Exp);
        println!("exp = {}", exp);

        let plus = function(Function::Plus);
        println!("plus = {}", plus);

        // Image of `plus` over the `&` combination of two float data types.
        let float_pair = DataType::float() & DataType::float();
        let plus_image = plus.super_image(&float_pair).unwrap();
        println!("plus.super_image({}) = {}", &float_pair, plus_image);

        // One aggregate from the table, one built by the fallback block.
        let count = aggregate(Aggregate::Count);
        println!("count = {}", count);

        let quantile = aggregate(Aggregate::Quantile(5.0));
        println!("quantile = {}", quantile);
    }
}