dply/
signatures.rs

1// Copyright (C) 2023 Vince Vasta
2// SPDX-License-Identifier: Apache-2.0
//! Function signatures and their argument types.
4use std::collections::HashMap;
5use std::sync::OnceLock;
6
7use crate::fuzzy;
8
9pub type SignaturesMap = HashMap<&'static str, Args>;
10
11pub fn functions() -> &'static SignaturesMap {
12    static SIGNATURES: OnceLock<SignaturesMap> = OnceLock::new();
13
14    SIGNATURES.get_or_init(|| {
15        let mut signatures = HashMap::new();
16
17        def_arrange(&mut signatures);
18        def_config(&mut signatures);
19        def_count(&mut signatures);
20        def_csv(&mut signatures);
21        def_distinct(&mut signatures);
22        def_filter(&mut signatures);
23        def_glimpse(&mut signatures);
24        def_group_by(&mut signatures);
25        def_head(&mut signatures);
26        def_joins(&mut signatures);
27        def_json(&mut signatures);
28        def_mutate(&mut signatures);
29        def_parquet(&mut signatures);
30        def_relocate(&mut signatures);
31        def_rename(&mut signatures);
32        def_show(&mut signatures);
33        def_select(&mut signatures);
34        def_summarize(&mut signatures);
35        def_unnest(&mut signatures);
36
37        signatures
38    })
39}
40
41pub fn completions(pattern: &str) -> Vec<String> {
42    static NAMES: OnceLock<Vec<String>> = OnceLock::new();
43
44    let names = NAMES.get_or_init(|| {
45        let mut names = Vec::with_capacity(1024);
46
47        for (name, args) in functions() {
48            let name = if let Args::None = args {
49                format!("{name}()")
50            } else if has_string_arg(name) {
51                format!("{name}(\"")
52            } else {
53                format!("{name}(")
54            };
55
56            names.push(name);
57            names.extend(args.names());
58        }
59
60        names.push("true".to_string());
61        names.push("false".to_string());
62
63        names.sort();
64        names.dedup();
65
66        names
67    });
68
69    let matcher = fuzzy::Matcher::new(pattern);
70
71    names
72        .iter()
73        .filter(|s| matcher.is_match(s))
74        .map(|s| s.to_string())
75        .collect()
76}
77
/// Returns true if the function called `name` is completed with an opening
/// string literal.
fn has_string_arg(name: &str) -> bool {
    // "contains" is left out on purpose: the one used in filter doesn't take
    // a string parameter (e.g. filter(contains(name, "john"))).
    const STRING_ARG_FUNCTIONS: [&str; 5] =
        ["parquet", "csv", "json", "starts_with", "ends_with"];

    STRING_ARG_FUNCTIONS.contains(&name)
}
86
/// The arguments accepted by a function signature.
///
/// Each variant states how many arguments may be passed and which
/// [`ArgType`] they must have.
#[derive(Debug, Clone)]
pub enum Args {
    /// No arguments.
    None,
    /// Zero or one argument of the given type.
    NoneOrOne(ArgType),
    /// Zero or more arguments of the given type.
    ZeroOrMore(ArgType),
    /// One or more arguments of the given type.
    OneOrMore(ArgType),
    /// One argument of the first type and zero or more arguments of the second.
    OneThenMore(ArgType, ArgType),
    /// A function with a fixed number of arguments, described by the listed types.
    Ordered(Vec<ArgType>),
}
102
103impl Args {
104    /// Extracts all the function and variable names in this arguments.
105    fn names(&self) -> Vec<String> {
106        let mut names = Vec::new();
107
108        match self {
109            Args::NoneOrOne(arg) => names.extend(arg.names()),
110            Args::ZeroOrMore(arg) => names.extend(arg.names()),
111            Args::OneOrMore(arg) => names.extend(arg.names()),
112            Args::OneThenMore(first, rest) => {
113                names.extend(first.names());
114                names.extend(rest.names());
115            }
116            Args::Ordered(args) => {
117                for arg in args {
118                    names.extend(arg.names());
119                }
120            }
121            _ => {}
122        }
123
124        names.sort();
125        names.dedup();
126        names
127    }
128}
129
/// Function argument type.
///
/// Types nest: expression variants box the type of their operands and
/// `Function` carries the [`Args`] of the nested call.
#[derive(Debug, Clone)]
pub enum ArgType {
    /// An arithmetic expression (+, *, -, /) over the given operand type.
    Arith(Box<ArgType>),
    /// An assign expression with its left and right hand side types.
    Assign(Box<ArgType>, Box<ArgType>),
    /// A bool type.
    Bool,
    /// A compare expression (<, >, !=, ==, <=, >=) with its left and right
    /// hand side types.
    Compare(Box<ArgType>, Box<ArgType>),
    /// An equality expression with its left and right hand side types.
    Eq(Box<ArgType>, Box<ArgType>),
    /// A function call expression: the function name and its arguments.
    Function(&'static str, Box<Args>),
    /// An identifier expression.
    Identifier,
    /// A logical expression (&, |) over the given operand type.
    Logical(Box<ArgType>),
    /// A named identifier.
    Named(&'static str),
    /// A negation expression of the given type.
    Negate(Box<ArgType>),
    /// A number.
    Number,
    /// A multi type argument: any one of the listed types.
    OneOf(Vec<ArgType>),
    /// A string argument.
    String,
}
160
161impl ArgType {
162    /// Creates an assignment type.
163    fn assign(lhs: ArgType, rhs: ArgType) -> Self {
164        Self::Assign(lhs.into(), rhs.into())
165    }
166
167    /// Creates an arithmetic type (+, *, -, /).
168    fn arith(arg: ArgType) -> Self {
169        Self::Arith(arg.into())
170    }
171
172    /// Creates a comparison type (<, >, !=, ==, <=, >=).
173    fn compare(lhs: ArgType, rhs: ArgType) -> Self {
174        Self::Compare(lhs.into(), rhs.into())
175    }
176
177    /// Creates an equality type.
178    fn eq(lhs: ArgType, rhs: ArgType) -> Self {
179        Self::Eq(lhs.into(), rhs.into())
180    }
181
182    /// Creates a logical type (&, |).
183    fn logical(arg: ArgType) -> Self {
184        Self::Logical(arg.into())
185    }
186
187    /// Creates a not type.
188    fn negate(arg: ArgType) -> Self {
189        Self::Negate(arg.into())
190    }
191
192    /// Creates a function ßtype.
193    fn function(name: &'static str, args: Args) -> Self {
194        ArgType::Function(name, args.into())
195    }
196
197    /// Extracts all named types.
198    fn names(&self) -> Vec<String> {
199        let mut names = Vec::new();
200
201        match self {
202            ArgType::Arith(arg) => names.extend(arg.names()),
203            ArgType::Assign(lhs, rhs) => {
204                names.extend(lhs.names());
205                names.extend(rhs.names());
206            }
207            ArgType::Compare(lhs, rhs) => {
208                names.extend(lhs.names());
209                names.extend(rhs.names());
210            }
211            ArgType::Eq(lhs, rhs) => {
212                names.extend(lhs.names());
213                names.extend(rhs.names());
214            }
215            ArgType::Function(name, args) => {
216                let name = if let Args::None = args.as_ref() {
217                    format!("{name}()")
218                } else {
219                    format!("{name}(")
220                };
221
222                names.push(name);
223                names.extend(args.names());
224            }
225            ArgType::Logical(arg) => names.extend(arg.names()),
226            ArgType::Named(name) => names.push(name.to_string()),
227            ArgType::Negate(arg) => names.extend(arg.names()),
228            ArgType::OneOf(args) => {
229                for arg in args {
230                    names.extend(arg.names());
231                }
232            }
233            _ => {}
234        }
235
236        names
237    }
238}
239
240fn def_arrange(signatures: &mut SignaturesMap) {
241    signatures.insert(
242        "arrange",
243        Args::OneOrMore(ArgType::OneOf(vec![
244            ArgType::Identifier,
245            ArgType::function("desc", Args::Ordered(vec![ArgType::Identifier])),
246        ])),
247    );
248}
249
250fn def_config(signatures: &mut SignaturesMap) {
251    signatures.insert(
252        "config",
253        Args::ZeroOrMore(ArgType::OneOf(vec![
254            ArgType::assign(ArgType::Named("max_columns"), ArgType::Number),
255            ArgType::assign(ArgType::Named("max_column_width"), ArgType::Number),
256            ArgType::assign(ArgType::Named("max_table_width"), ArgType::Number),
257        ])),
258    );
259}
260
261fn def_count(signatures: &mut SignaturesMap) {
262    signatures.insert(
263        "count",
264        Args::ZeroOrMore(ArgType::OneOf(vec![
265            ArgType::Identifier,
266            ArgType::assign(ArgType::Named("sort"), ArgType::Bool),
267        ])),
268    );
269}
270
271fn def_csv(signatures: &mut SignaturesMap) {
272    signatures.insert(
273        "csv",
274        Args::OneThenMore(
275            ArgType::String,
276            ArgType::assign(ArgType::Named("overwrite"), ArgType::Bool),
277        ),
278    );
279}
280
281fn def_distinct(signatures: &mut SignaturesMap) {
282    signatures.insert("distinct", Args::OneOrMore(ArgType::Identifier));
283}
284
285fn def_filter(signatures: &mut SignaturesMap) {
286    let compare_args = ArgType::compare(
287        ArgType::Identifier,
288        ArgType::OneOf(vec![
289            ArgType::Identifier,
290            ArgType::Number,
291            ArgType::String,
292            ArgType::Bool,
293            ArgType::function("dt", Args::Ordered(vec![ArgType::String])),
294        ]),
295    );
296
297    let contains_fn = ArgType::function(
298        "contains",
299        Args::Ordered(vec![
300            ArgType::Identifier,
301            ArgType::OneOf(vec![ArgType::String, ArgType::Number]),
302        ]),
303    );
304
305    let is_null_fn = ArgType::function("is_null", Args::Ordered(vec![ArgType::Identifier]));
306
307    let predicates = ArgType::OneOf(vec![
308        contains_fn.clone(),
309        ArgType::negate(contains_fn),
310        is_null_fn.clone(),
311        ArgType::negate(is_null_fn),
312    ]);
313
314    let filter_arg = ArgType::OneOf(vec![compare_args, predicates]);
315
316    signatures.insert(
317        "filter",
318        Args::OneOrMore(ArgType::OneOf(vec![
319            filter_arg.clone(),
320            ArgType::logical(filter_arg),
321        ])),
322    );
323}
324
325fn def_glimpse(signatures: &mut SignaturesMap) {
326    signatures.insert("glimpse", Args::None);
327}
328
329fn def_group_by(signatures: &mut SignaturesMap) {
330    signatures.insert("group_by", Args::OneOrMore(ArgType::Identifier));
331}
332
333fn def_head(signatures: &mut SignaturesMap) {
334    signatures.insert("head", Args::NoneOrOne(ArgType::Number));
335}
336
337fn def_joins(signatures: &mut SignaturesMap) {
338    let args = Args::OneThenMore(
339        ArgType::Identifier,
340        ArgType::eq(ArgType::Identifier, ArgType::Identifier),
341    );
342
343    signatures.insert("anti_join", args.clone());
344    signatures.insert("cross_join", args.clone());
345    signatures.insert("inner_join", args.clone());
346    signatures.insert("left_join", args.clone());
347    signatures.insert("outer_join", args);
348}
349
350fn def_json(signatures: &mut SignaturesMap) {
351    signatures.insert(
352        "json",
353        Args::OneThenMore(
354            ArgType::String,
355            ArgType::OneOf(vec![
356                ArgType::assign(ArgType::Named("overwrite"), ArgType::Bool),
357                ArgType::assign(ArgType::Named("schema_rows"), ArgType::Number),
358            ]),
359        ),
360    );
361}
362
363fn def_mutate(signatures: &mut SignaturesMap) {
364    let operand = ArgType::OneOf(vec![
365        ArgType::Identifier,
366        ArgType::Number,
367        ArgType::String,
368        ArgType::function("ymd_hms", Args::Ordered(vec![ArgType::Identifier])),
369        ArgType::function("dnanos", Args::Ordered(vec![ArgType::Identifier])),
370        ArgType::function("dmicros", Args::Ordered(vec![ArgType::Identifier])),
371        ArgType::function("dmillis", Args::Ordered(vec![ArgType::Identifier])),
372        ArgType::function("dsecs", Args::Ordered(vec![ArgType::Identifier])),
373        ArgType::function("nanos", Args::Ordered(vec![ArgType::Identifier])),
374        ArgType::function("micros", Args::Ordered(vec![ArgType::Identifier])),
375        ArgType::function("millis", Args::Ordered(vec![ArgType::Identifier])),
376        ArgType::function("secs", Args::Ordered(vec![ArgType::Identifier])),
377        ArgType::function(
378            "field",
379            Args::Ordered(vec![ArgType::Identifier, ArgType::Identifier]),
380        ),
381        ArgType::function("len", Args::Ordered(vec![ArgType::Identifier])),
382        ArgType::function("max", Args::Ordered(vec![ArgType::Identifier])),
383        ArgType::function("mean", Args::Ordered(vec![ArgType::Identifier])),
384        ArgType::function("median", Args::Ordered(vec![ArgType::Identifier])),
385        ArgType::function("min", Args::Ordered(vec![ArgType::Identifier])),
386        ArgType::function("row", Args::None),
387    ]);
388
389    let expr = ArgType::OneOf(vec![operand.clone(), ArgType::arith(operand)]);
390
391    signatures.insert(
392        "mutate",
393        Args::OneOrMore(ArgType::assign(ArgType::Identifier, expr)),
394    );
395}
396
397fn def_parquet(signatures: &mut SignaturesMap) {
398    signatures.insert(
399        "parquet",
400        Args::OneThenMore(
401            ArgType::String,
402            ArgType::assign(ArgType::Named("overwrite"), ArgType::Bool),
403        ),
404    );
405}
406
407fn def_relocate(signatures: &mut SignaturesMap) {
408    signatures.insert(
409        "relocate",
410        Args::OneOrMore(ArgType::OneOf(vec![
411            ArgType::Identifier,
412            ArgType::assign(ArgType::Named("after"), ArgType::Identifier),
413            ArgType::assign(ArgType::Named("before"), ArgType::Identifier),
414        ])),
415    );
416}
417
418fn def_rename(signatures: &mut SignaturesMap) {
419    signatures.insert(
420        "rename",
421        Args::OneOrMore(ArgType::assign(ArgType::Identifier, ArgType::Identifier)),
422    );
423}
424
425fn def_select(signatures: &mut SignaturesMap) {
426    let contains_fn = ArgType::function("contains", Args::Ordered(vec![ArgType::String]));
427    let ends_with_fn = ArgType::function("ends_with", Args::Ordered(vec![ArgType::String]));
428    let start_with_fn = ArgType::function("starts_with", Args::Ordered(vec![ArgType::String]));
429
430    signatures.insert(
431        "select",
432        Args::OneOrMore(ArgType::OneOf(vec![
433            ArgType::Identifier,
434            ArgType::assign(ArgType::Identifier, ArgType::Identifier),
435            contains_fn.clone(),
436            ArgType::negate(contains_fn),
437            ends_with_fn.clone(),
438            ArgType::negate(ends_with_fn),
439            start_with_fn.clone(),
440            ArgType::negate(start_with_fn),
441        ])),
442    );
443}
444
445fn def_summarize(signatures: &mut SignaturesMap) {
446    signatures.insert(
447        "summarize",
448        Args::OneOrMore(ArgType::Assign(
449            Box::new(ArgType::Identifier),
450            Box::new(ArgType::OneOf(vec![
451                ArgType::function("list", Args::Ordered(vec![ArgType::Identifier])),
452                ArgType::function("max", Args::Ordered(vec![ArgType::Identifier])),
453                ArgType::function("mean", Args::Ordered(vec![ArgType::Identifier])),
454                ArgType::function("median", Args::Ordered(vec![ArgType::Identifier])),
455                ArgType::function("min", Args::Ordered(vec![ArgType::Identifier])),
456                ArgType::function("n", Args::None),
457                ArgType::function(
458                    "quantile",
459                    Args::Ordered(vec![ArgType::Identifier, ArgType::Number]),
460                ),
461                ArgType::function("sd", Args::Ordered(vec![ArgType::Identifier])),
462                ArgType::function("sum", Args::Ordered(vec![ArgType::Identifier])),
463                ArgType::function("var", Args::Ordered(vec![ArgType::Identifier])),
464            ])),
465        )),
466    );
467}
468
469fn def_show(signatures: &mut SignaturesMap) {
470    signatures.insert("show", Args::None);
471}
472
473fn def_unnest(signatures: &mut SignaturesMap) {
474    signatures.insert("unnest", Args::OneOrMore(ArgType::Identifier));
475}