1use super::{aggregate::Aggregate, function::Function};
2use crate::data_type::{
3 function::{self, Optional},
4 DataType,
5};
6use paste::paste;
7use rand::rngs::OsRng;
8use std::sync::{Arc, Mutex};
9
/// Generates the per-thread cache of scalar function implementations and the public
/// `function` lookup that serves it.
///
/// Expected arguments, in order:
/// - five bracketed, comma-separated lists of `Function` variant names grouped by arity
///   (nullary, unary, binary, ternary, quaternary). A variant `Foo` is built once per thread
///   by calling `function::foo()` (its snake-case name); nullary entries are stored as-is,
///   every other entry is wrapped in `Optional::new`;
/// - an identifier that binds any `Function` value not matched by the lists;
/// - a block evaluated on every call, uncached, to build the implementation for that value.
macro_rules! function_implementations {
    (
        [$($zero:ident),*],
        [$($one:ident),*],
        [$($two:ident),*],
        [$($three:ident),*],
        [$($four:ident),*],
        $other:ident,
        $fallback:block
    ) => {
        paste! {
            thread_local! {
                /// Per-thread table holding one shared instance of every listed implementation.
                static FUNCTION_IMPLEMENTATIONS: FunctionImplementations = FunctionImplementations {
                    // Nullary implementations are stored without an `Optional` wrapper.
                    $([< $zero:snake >]: Arc::new(function::[< $zero:snake >]()),)*
                    $([< $one:snake >]: Arc::new(Optional::new(function::[< $one:snake >]())),)*
                    $([< $two:snake >]: Arc::new(Optional::new(function::[< $two:snake >]())),)*
                    $([< $three:snake >]: Arc::new(Optional::new(function::[< $three:snake >]())),)*
                    $([< $four:snake >]: Arc::new(Optional::new(function::[< $four:snake >]())),)*
                };
            }

            /// One field per listed `Function` variant, named after the variant in snake case.
            struct FunctionImplementations {
                $(pub [< $zero:snake >]: Arc<dyn function::Function>,)*
                $(pub [< $one:snake >]: Arc<dyn function::Function>,)*
                $(pub [< $two:snake >]: Arc<dyn function::Function>,)*
                $(pub [< $three:snake >]: Arc<dyn function::Function>,)*
                $(pub [< $four:snake >]: Arc<dyn function::Function>,)*
            }

            /// Returns the implementation of `function`.
            ///
            /// Variants named in the macro lists are served from the per-thread table, so the
            /// returned `Arc` shares the cached instance; anything else is built by the
            /// fallback block on every call.
            pub fn function(function: Function) -> Arc<dyn function::Function> {
                match function {
                    $(Function::$zero => FUNCTION_IMPLEMENTATIONS.with(|table| Arc::clone(&table.[< $zero:snake >])),)*
                    $(Function::$one => FUNCTION_IMPLEMENTATIONS.with(|table| Arc::clone(&table.[< $one:snake >])),)*
                    $(Function::$two => FUNCTION_IMPLEMENTATIONS.with(|table| Arc::clone(&table.[< $two:snake >])),)*
                    $(Function::$three => FUNCTION_IMPLEMENTATIONS.with(|table| Arc::clone(&table.[< $three:snake >])),)*
                    $(Function::$four => FUNCTION_IMPLEMENTATIONS.with(|table| Arc::clone(&table.[< $four:snake >])),)*
                    $other => $fallback
                }
            }
        }
    };
}
47
48function_implementations!(
56 [Pi, Newid, CurrentDate, CurrentTime, CurrentTimestamp],
57 [
58 Opposite,
59 Not,
60 Exp,
61 Ln,
62 Log,
63 Abs,
64 Sin,
65 Cos,
66 Sqrt,
67 Md5,
68 Ceil,
69 Floor,
70 Sign,
71 Unhex,
72 Dayname,
73 Quarter,
74 Date,
75 UnixTimestamp,
76 IsNull
77 ],
78 [
79 Plus,
80 Minus,
81 Multiply,
82 Divide,
83 Modulo,
84 StringConcat,
85 Gt,
86 Lt,
87 GtEq,
88 LtEq,
89 Eq,
90 NotEq,
91 And,
92 Or,
93 Xor,
94 BitwiseOr,
95 BitwiseAnd,
96 BitwiseXor,
97 Pow,
98 CharLength,
99 Lower,
100 Upper,
101 InList,
102 Least,
103 Greatest,
104 Rtrim,
105 Ltrim,
106 Substr,
107 Round,
108 Trunc,
109 RegexpContains,
110 Encode,
111 Decode,
112 ExtractEpoch,
113 ExtractYear,
114 ExtractMonth,
115 ExtractDay,
116 ExtractHour,
117 ExtractMinute,
118 ExtractSecond,
119 ExtractMicrosecond,
120 ExtractMillisecond,
121 ExtractDow,
122 ExtractWeek,
123 DateFormat,
124 FromUnixtime,
125 Like,
126 Ilike,
127 IsBool
128 ],
129 [Case, Position, SubstrWithSize, RegexpReplace, DatetimeDiff],
130 [RegexpExtract],
131 x,
132 {
133 match x {
134 Function::CastAsText => Arc::new(function::cast(DataType::text())),
135 Function::CastAsInteger => Arc::new(Optional::new(function::cast(DataType::integer()))),
136 Function::CastAsFloat => Arc::new(Optional::new(function::cast(DataType::float()))),
137 Function::CastAsBoolean => Arc::new(Optional::new(function::cast(DataType::boolean()))),
138 Function::CastAsDateTime => {
139 Arc::new(Optional::new(function::cast(DataType::date_time())))
140 }
141 Function::CastAsDate => Arc::new(Optional::new(function::cast(DataType::date()))),
142 Function::CastAsTime => Arc::new(Optional::new(function::cast(DataType::time()))),
143 Function::Concat(n) => Arc::new(function::concat(n)),
144 Function::Random(_n) => Arc::new(function::random(Mutex::new(OsRng))), Function::Coalesce => Arc::new(function::coalesce()),
146 _ => unreachable!(),
147 }
148 }
149);
150
/// Generates the per-thread cache of aggregate implementations and the public `aggregate`
/// lookup that serves it.
///
/// Expected arguments, in order:
/// - a bracketed, comma-separated list of `Aggregate` variant names; a variant `Foo` is built
///   once per thread as `Optional::new(function::foo())` (its snake-case name);
/// - an identifier that binds any `Aggregate` value not matched by the list;
/// - a block evaluated on every call, uncached, to build the implementation for that value.
macro_rules! aggregate_implementations {
    ([$($cached:ident),*], $other:ident, $fallback:block) => {
        paste! {
            thread_local! {
                /// Per-thread table holding one shared instance of every listed aggregate.
                static AGGREGATE_IMPLEMENTATIONS: AggregateImplementations = AggregateImplementations {
                    $([< $cached:snake >]: Arc::new(Optional::new(function::[< $cached:snake >]())),)*
                };
            }

            /// One field per listed `Aggregate` variant, named after the variant in snake case.
            struct AggregateImplementations {
                $(pub [< $cached:snake >]: Arc<dyn function::Function>,)*
            }

            /// Returns the implementation of `aggregate`.
            ///
            /// Variants named in the macro list are served from the per-thread table, so the
            /// returned `Arc` shares the cached instance; anything else is built by the
            /// fallback block on every call.
            pub fn aggregate(aggregate: Aggregate) -> Arc<dyn function::Function> {
                match aggregate {
                    $(Aggregate::$cached => AGGREGATE_IMPLEMENTATIONS.with(|table| Arc::clone(&table.[< $cached:snake >])),)*
                    $other => $fallback
                }
            }
        }
    };
}
176
177aggregate_implementations!(
178 [
179 Min,
180 Max,
181 Median,
182 NUnique,
183 First,
184 Last,
185 Mean,
186 List,
187 Count,
188 Sum,
189 AggGroups,
190 Std,
191 Var,
192 MeanDistinct,
193 CountDistinct,
194 SumDistinct,
195 StdDistinct,
196 VarDistinct
197 ],
198 x,
199 {
200 match x {
201 Aggregate::Quantile(p) => Arc::new(function::quantile(p)),
202 Aggregate::Quantiles(p) => Arc::new(function::quantiles(p.iter().cloned().collect())),
203 _ => unreachable!(),
204 }
205 }
206);
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when both handles point at the same allocation.
    ///
    /// Only the data pointers are compared (trait-object metadata is discarded by the
    /// cast), so the result does not depend on vtable identity.
    fn same_allocation<T: ?Sized>(a: &Arc<T>, b: &Arc<T>) -> bool {
        std::ptr::eq(Arc::as_ptr(a).cast::<()>(), Arc::as_ptr(b).cast::<()>())
    }

    /// Smoke test: builds a few implementations and prints them.
    #[test]
    fn test_implementations() {
        println!("exp = {}", function(Function::Exp));
        println!("plus = {}", function(Function::Plus));
        println!(
            "plus.super_image({}) = {}",
            &(DataType::float() & DataType::float()),
            function(Function::Plus)
                .super_image(&(DataType::float() & DataType::float()))
                .unwrap()
        );
        println!("count = {}", aggregate(Aggregate::Count));
        println!("quantile = {}", aggregate(Aggregate::Quantile(5.0)));
    }

    /// Listed variants are served from the thread-local table, so repeated lookups on the
    /// same thread must hand back clones of one shared instance.
    #[test]
    fn test_cached_implementations_are_shared() {
        assert!(same_allocation(
            &function(Function::Pi),
            &function(Function::Pi)
        ));
        assert!(same_allocation(
            &function(Function::Exp),
            &function(Function::Exp)
        ));
        assert!(same_allocation(
            &function(Function::Plus),
            &function(Function::Plus)
        ));
        assert!(same_allocation(
            &aggregate(Aggregate::Count),
            &aggregate(Aggregate::Count)
        ));
    }
}