jj_lib/
dsl_util.rs

1// Copyright 2020-2024 The Jujutsu Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Domain-specific language helpers.
16
17use std::ascii;
18use std::collections::HashMap;
19use std::fmt;
20use std::slice;
21
22use itertools::Itertools as _;
23use pest::RuleType;
24use pest::iterators::Pair;
25use pest::iterators::Pairs;
26
/// Collects diagnostic messages reported while parsing.
///
/// `T` is typically the language's parse error type, carrying a message and a
/// source span with `'static` lifetime.
#[derive(Debug)]
pub struct Diagnostics<T> {
    // May later become a list of { kind: Warning|Error, message: T } entries.
    diagnostics: Vec<T>,
}

impl<T> Diagnostics<T> {
    /// Returns a collector holding no messages.
    pub fn new() -> Self {
        Diagnostics {
            diagnostics: vec![],
        }
    }

    /// Whether no message has been recorded yet.
    pub fn is_empty(&self) -> bool {
        self.diagnostics.is_empty()
    }

    /// Count of recorded messages.
    pub fn len(&self) -> usize {
        self.diagnostics.len()
    }

    /// Borrowing iterator over the recorded messages, in insertion order.
    pub fn iter(&self) -> slice::Iter<'_, T> {
        self.diagnostics.as_slice().iter()
    }

    /// Records a warning-level message.
    pub fn add_warning(&mut self, diag: T) {
        self.diagnostics.push(diag);
    }

    /// Absorbs the messages of a collector of another type, converting each
    /// one with `f` (e.g. fileset warnings raised inside a `file()` revset).
    pub fn extend_with<U>(&mut self, diagnostics: Diagnostics<U>, f: impl FnMut(U) -> T) {
        let converted = diagnostics.diagnostics.into_iter().map(f);
        self.diagnostics.extend(converted);
    }
}

impl<T> Default for Diagnostics<T> {
    /// Same as [`Diagnostics::new`].
    fn default() -> Self {
        Diagnostics::new()
    }
}

impl<'a, T> IntoIterator for &'a Diagnostics<T> {
    type Item = &'a T;
    type IntoIter = slice::Iter<'a, T>;

    fn into_iter(self) -> Self::IntoIter {
        self.diagnostics.iter()
    }
}
87
88/// AST node without type or name checking.
89#[derive(Clone, Debug, Eq, PartialEq)]
90pub struct ExpressionNode<'i, T> {
91    /// Expression item such as identifier, literal, function call, etc.
92    pub kind: T,
93    /// Span of the node.
94    pub span: pest::Span<'i>,
95}
96
97impl<'i, T> ExpressionNode<'i, T> {
98    /// Wraps the given expression and span.
99    pub fn new(kind: T, span: pest::Span<'i>) -> Self {
100        Self { kind, span }
101    }
102}
103
104/// Function call in AST.
105#[derive(Clone, Debug, Eq, PartialEq)]
106pub struct FunctionCallNode<'i, T> {
107    /// Function name.
108    pub name: &'i str,
109    /// Span of the function name.
110    pub name_span: pest::Span<'i>,
111    /// List of positional arguments.
112    pub args: Vec<ExpressionNode<'i, T>>,
113    /// List of keyword arguments.
114    pub keyword_args: Vec<KeywordArgument<'i, T>>,
115    /// Span of the arguments list.
116    pub args_span: pest::Span<'i>,
117}
118
119/// Keyword argument pair in AST.
120#[derive(Clone, Debug, Eq, PartialEq)]
121pub struct KeywordArgument<'i, T> {
122    /// Parameter name.
123    pub name: &'i str,
124    /// Span of the parameter name.
125    pub name_span: pest::Span<'i>,
126    /// Value expression.
127    pub value: ExpressionNode<'i, T>,
128}
129
130impl<'i, T> FunctionCallNode<'i, T> {
131    /// Number of arguments assuming named arguments are all unique.
132    pub fn arity(&self) -> usize {
133        self.args.len() + self.keyword_args.len()
134    }
135
136    /// Ensures that no arguments passed.
137    pub fn expect_no_arguments(&self) -> Result<(), InvalidArguments<'i>> {
138        let ([], []) = self.expect_arguments()?;
139        Ok(())
140    }
141
142    /// Extracts exactly N required arguments.
143    pub fn expect_exact_arguments<const N: usize>(
144        &self,
145    ) -> Result<&[ExpressionNode<'i, T>; N], InvalidArguments<'i>> {
146        let (args, []) = self.expect_arguments()?;
147        Ok(args)
148    }
149
150    /// Extracts N required arguments and remainders.
151    ///
152    /// This can be used to get all the positional arguments without requiring
153    /// any (N = 0):
154    /// ```ignore
155    /// let ([], content_nodes) = function.expect_some_arguments()?;
156    /// ```
157    /// Avoid accessing `function.args` directly, as that may allow keyword
158    /// arguments to be silently ignored.
159    #[expect(clippy::type_complexity)]
160    pub fn expect_some_arguments<const N: usize>(
161        &self,
162    ) -> Result<(&[ExpressionNode<'i, T>; N], &[ExpressionNode<'i, T>]), InvalidArguments<'i>> {
163        self.ensure_no_keyword_arguments()?;
164        if self.args.len() >= N {
165            let (required, rest) = self.args.split_at(N);
166            Ok((required.try_into().unwrap(), rest))
167        } else {
168            Err(self.invalid_arguments_count(N, None))
169        }
170    }
171
172    /// Extracts N required arguments and M optional arguments.
173    #[expect(clippy::type_complexity)]
174    pub fn expect_arguments<const N: usize, const M: usize>(
175        &self,
176    ) -> Result<
177        (
178            &[ExpressionNode<'i, T>; N],
179            [Option<&ExpressionNode<'i, T>>; M],
180        ),
181        InvalidArguments<'i>,
182    > {
183        self.ensure_no_keyword_arguments()?;
184        let count_range = N..=(N + M);
185        if count_range.contains(&self.args.len()) {
186            let (required, rest) = self.args.split_at(N);
187            let mut optional = rest.iter().map(Some).collect_vec();
188            optional.resize(M, None);
189            Ok((
190                required.try_into().unwrap(),
191                optional.try_into().ok().unwrap(),
192            ))
193        } else {
194            let (min, max) = count_range.into_inner();
195            Err(self.invalid_arguments_count(min, Some(max)))
196        }
197    }
198
199    /// Extracts N required arguments and M optional arguments. Some of them can
200    /// be specified as keyword arguments.
201    ///
202    /// `names` is a list of parameter names. Unnamed positional arguments
203    /// should be padded with `""`.
204    #[expect(clippy::type_complexity)]
205    pub fn expect_named_arguments<const N: usize, const M: usize>(
206        &self,
207        names: &[&str],
208    ) -> Result<
209        (
210            [&ExpressionNode<'i, T>; N],
211            [Option<&ExpressionNode<'i, T>>; M],
212        ),
213        InvalidArguments<'i>,
214    > {
215        if self.keyword_args.is_empty() {
216            let (required, optional) = self.expect_arguments::<N, M>()?;
217            Ok((required.each_ref(), optional))
218        } else {
219            let (required, optional) = self.expect_named_arguments_vec(names, N, N + M)?;
220            Ok((
221                required.try_into().ok().unwrap(),
222                optional.try_into().ok().unwrap(),
223            ))
224        }
225    }
226
227    #[expect(clippy::type_complexity)]
228    fn expect_named_arguments_vec(
229        &self,
230        names: &[&str],
231        min: usize,
232        max: usize,
233    ) -> Result<
234        (
235            Vec<&ExpressionNode<'i, T>>,
236            Vec<Option<&ExpressionNode<'i, T>>>,
237        ),
238        InvalidArguments<'i>,
239    > {
240        assert!(names.len() <= max);
241
242        if self.args.len() > max {
243            return Err(self.invalid_arguments_count(min, Some(max)));
244        }
245        let mut extracted = Vec::with_capacity(max);
246        extracted.extend(self.args.iter().map(Some));
247        extracted.resize(max, None);
248
249        for arg in &self.keyword_args {
250            let name = arg.name;
251            let span = arg.name_span.start_pos().span(&arg.value.span.end_pos());
252            let pos = names.iter().position(|&n| n == name).ok_or_else(|| {
253                self.invalid_arguments(format!(r#"Unexpected keyword argument "{name}""#), span)
254            })?;
255            if extracted[pos].is_some() {
256                return Err(self.invalid_arguments(
257                    format!(r#"Got multiple values for keyword "{name}""#),
258                    span,
259                ));
260            }
261            extracted[pos] = Some(&arg.value);
262        }
263
264        let optional = extracted.split_off(min);
265        let required = extracted.into_iter().flatten().collect_vec();
266        if required.len() != min {
267            return Err(self.invalid_arguments_count(min, Some(max)));
268        }
269        Ok((required, optional))
270    }
271
272    fn ensure_no_keyword_arguments(&self) -> Result<(), InvalidArguments<'i>> {
273        if let (Some(first), Some(last)) = (self.keyword_args.first(), self.keyword_args.last()) {
274            let span = first.name_span.start_pos().span(&last.value.span.end_pos());
275            Err(self.invalid_arguments("Unexpected keyword arguments".to_owned(), span))
276        } else {
277            Ok(())
278        }
279    }
280
281    fn invalid_arguments(&self, message: String, span: pest::Span<'i>) -> InvalidArguments<'i> {
282        InvalidArguments {
283            name: self.name,
284            message,
285            span,
286        }
287    }
288
289    fn invalid_arguments_count(&self, min: usize, max: Option<usize>) -> InvalidArguments<'i> {
290        let message = match (min, max) {
291            (min, Some(max)) if min == max => format!("Expected {min} arguments"),
292            (min, Some(max)) => format!("Expected {min} to {max} arguments"),
293            (min, None) => format!("Expected at least {min} arguments"),
294        };
295        self.invalid_arguments(message, self.args_span)
296    }
297
298    fn invalid_arguments_count_with_arities(
299        &self,
300        arities: impl IntoIterator<Item = usize>,
301    ) -> InvalidArguments<'i> {
302        let message = format!("Expected {} arguments", arities.into_iter().join(", "));
303        self.invalid_arguments(message, self.args_span)
304    }
305}
306
307/// Unexpected number of arguments, or invalid combination of arguments.
308///
309/// This error is supposed to be converted to language-specific parse error
310/// type, where lifetime `'i` will be eliminated.
311#[derive(Clone, Debug)]
312pub struct InvalidArguments<'i> {
313    /// Function name.
314    pub name: &'i str,
315    /// Error message.
316    pub message: String,
317    /// Span of the bad arguments.
318    pub span: pest::Span<'i>,
319}
320
321/// Expression item that can be transformed recursively by using `folder: F`.
322pub trait FoldableExpression<'i>: Sized {
323    /// Transforms `self` by applying the `folder` to inner items.
324    fn fold<F>(self, folder: &mut F, span: pest::Span<'i>) -> Result<Self, F::Error>
325    where
326        F: ExpressionFolder<'i, Self> + ?Sized;
327}
328
329/// Visitor-like interface to transform AST nodes recursively.
330pub trait ExpressionFolder<'i, T: FoldableExpression<'i>> {
331    /// Transform error.
332    type Error;
333
334    /// Transforms the expression `node`. By default, inner items are
335    /// transformed recursively.
336    fn fold_expression(
337        &mut self,
338        node: ExpressionNode<'i, T>,
339    ) -> Result<ExpressionNode<'i, T>, Self::Error> {
340        let ExpressionNode { kind, span } = node;
341        let kind = kind.fold(self, span)?;
342        Ok(ExpressionNode { kind, span })
343    }
344
345    /// Transforms identifier.
346    fn fold_identifier(&mut self, name: &'i str, span: pest::Span<'i>) -> Result<T, Self::Error>;
347
348    /// Transforms function call.
349    fn fold_function_call(
350        &mut self,
351        function: Box<FunctionCallNode<'i, T>>,
352        span: pest::Span<'i>,
353    ) -> Result<T, Self::Error>;
354}
355
356/// Transforms list of `nodes` by using `folder`.
357pub fn fold_expression_nodes<'i, F, T>(
358    folder: &mut F,
359    nodes: Vec<ExpressionNode<'i, T>>,
360) -> Result<Vec<ExpressionNode<'i, T>>, F::Error>
361where
362    F: ExpressionFolder<'i, T> + ?Sized,
363    T: FoldableExpression<'i>,
364{
365    nodes
366        .into_iter()
367        .map(|node| folder.fold_expression(node))
368        .try_collect()
369}
370
371/// Transforms function call arguments by using `folder`.
372pub fn fold_function_call_args<'i, F, T>(
373    folder: &mut F,
374    function: FunctionCallNode<'i, T>,
375) -> Result<FunctionCallNode<'i, T>, F::Error>
376where
377    F: ExpressionFolder<'i, T> + ?Sized,
378    T: FoldableExpression<'i>,
379{
380    Ok(FunctionCallNode {
381        name: function.name,
382        name_span: function.name_span,
383        args: fold_expression_nodes(folder, function.args)?,
384        keyword_args: function
385            .keyword_args
386            .into_iter()
387            .map(|arg| {
388                Ok(KeywordArgument {
389                    name: arg.name,
390                    name_span: arg.name_span,
391                    value: folder.fold_expression(arg.value)?,
392                })
393            })
394            .try_collect()?,
395        args_span: function.args_span,
396    })
397}
398
399/// Helper to parse string literal.
400#[derive(Debug)]
401pub struct StringLiteralParser<R> {
402    /// String content part.
403    pub content_rule: R,
404    /// Escape sequence part including backslash character.
405    pub escape_rule: R,
406}
407
408impl<R: RuleType> StringLiteralParser<R> {
409    /// Parses the given string literal `pairs` into string.
410    pub fn parse(&self, pairs: Pairs<R>) -> String {
411        let mut result = String::new();
412        for part in pairs {
413            if part.as_rule() == self.content_rule {
414                result.push_str(part.as_str());
415            } else if part.as_rule() == self.escape_rule {
416                match &part.as_str()[1..] {
417                    "\"" => result.push('"'),
418                    "\\" => result.push('\\'),
419                    "t" => result.push('\t'),
420                    "r" => result.push('\r'),
421                    "n" => result.push('\n'),
422                    "0" => result.push('\0'),
423                    "e" => result.push('\x1b'),
424                    hex if hex.starts_with('x') => {
425                        result.push(char::from(
426                            u8::from_str_radix(&hex[1..], 16).expect("hex characters"),
427                        ));
428                    }
429                    char => panic!("invalid escape: \\{char:?}"),
430                }
431            } else {
432                panic!("unexpected part of string: {part:?}");
433            }
434        }
435        result
436    }
437}
438
/// Escapes special characters in `unescaped` so the result can be placed
/// inside a double-quoted string literal.
///
/// Quotes, backslashes, `\t`, `\r`, `\n` and NUL get their short escapes.
/// Any other ASCII control character becomes a `\xHH` escape. All remaining
/// characters are copied verbatim.
pub fn escape_string(unescaped: &str) -> String {
    let mut out = String::with_capacity(unescaped.len());
    for ch in unescaped.chars() {
        let replacement = match ch {
            '"' => r#"\""#,
            '\\' => r#"\\"#,
            '\t' => r#"\t"#,
            '\r' => r#"\r"#,
            '\n' => r#"\n"#,
            '\0' => r#"\0"#,
            _ if ch.is_ascii_control() => {
                out.extend(ascii::escape_default(ch as u8).map(char::from));
                continue;
            }
            _ => {
                out.push(ch);
                continue;
            }
        };
        out.push_str(replacement);
    }
    out
}
460
461/// Helper to parse function call.
462#[derive(Debug)]
463pub struct FunctionCallParser<R> {
464    /// Function name.
465    pub function_name_rule: R,
466    /// List of positional and keyword arguments.
467    pub function_arguments_rule: R,
468    /// Pair of parameter name and value.
469    pub keyword_argument_rule: R,
470    /// Parameter name.
471    pub argument_name_rule: R,
472    /// Value expression.
473    pub argument_value_rule: R,
474}
475
476impl<R: RuleType> FunctionCallParser<R> {
477    /// Parses the given `pair` as function call.
478    pub fn parse<'i, T, E: From<InvalidArguments<'i>>>(
479        &self,
480        pair: Pair<'i, R>,
481        // parse_name can be defined for any Pair<'_, R>, but parse_value should
482        // be allowed to construct T by capturing Pair<'i, R>.
483        parse_name: impl Fn(Pair<'i, R>) -> Result<&'i str, E>,
484        parse_value: impl Fn(Pair<'i, R>) -> Result<ExpressionNode<'i, T>, E>,
485    ) -> Result<FunctionCallNode<'i, T>, E> {
486        let [name_pair, args_pair] = pair.into_inner().collect_array().unwrap();
487        assert_eq!(name_pair.as_rule(), self.function_name_rule);
488        assert_eq!(args_pair.as_rule(), self.function_arguments_rule);
489        let name_span = name_pair.as_span();
490        let args_span = args_pair.as_span();
491        let function_name = parse_name(name_pair)?;
492        let mut args = Vec::new();
493        let mut keyword_args = Vec::new();
494        for pair in args_pair.into_inner() {
495            let span = pair.as_span();
496            if pair.as_rule() == self.argument_value_rule {
497                if !keyword_args.is_empty() {
498                    return Err(InvalidArguments {
499                        name: function_name,
500                        message: "Positional argument follows keyword argument".to_owned(),
501                        span,
502                    }
503                    .into());
504                }
505                args.push(parse_value(pair)?);
506            } else if pair.as_rule() == self.keyword_argument_rule {
507                let [name_pair, value_pair] = pair.into_inner().collect_array().unwrap();
508                assert_eq!(name_pair.as_rule(), self.argument_name_rule);
509                assert_eq!(value_pair.as_rule(), self.argument_value_rule);
510                let name_span = name_pair.as_span();
511                let arg = KeywordArgument {
512                    name: parse_name(name_pair)?,
513                    name_span,
514                    value: parse_value(value_pair)?,
515                };
516                keyword_args.push(arg);
517            } else {
518                panic!("unexpected argument rule {pair:?}");
519            }
520        }
521        Ok(FunctionCallNode {
522            name: function_name,
523            name_span,
524            args,
525            keyword_args,
526            args_span,
527        })
528    }
529}
530
531/// Map of symbol and function aliases.
532#[derive(Clone, Debug, Default)]
533pub struct AliasesMap<P, V> {
534    symbol_aliases: HashMap<String, V>,
535    // name: [(params, defn)] (sorted by arity)
536    function_aliases: HashMap<String, Vec<(Vec<String>, V)>>,
537    // Parser type P helps prevent misuse of AliasesMap of different language.
538    parser: P,
539}
540
541impl<P, V> AliasesMap<P, V> {
542    /// Creates an empty aliases map with default-constructed parser.
543    pub fn new() -> Self
544    where
545        P: Default,
546    {
547        Self {
548            symbol_aliases: Default::default(),
549            function_aliases: Default::default(),
550            parser: Default::default(),
551        }
552    }
553
554    /// Adds new substitution rule `decl = defn`.
555    ///
556    /// Returns error if `decl` is invalid. The `defn` part isn't checked. A bad
557    /// `defn` will be reported when the alias is substituted.
558    pub fn insert(&mut self, decl: impl AsRef<str>, defn: impl Into<V>) -> Result<(), P::Error>
559    where
560        P: AliasDeclarationParser,
561    {
562        match self.parser.parse_declaration(decl.as_ref())? {
563            AliasDeclaration::Symbol(name) => {
564                self.symbol_aliases.insert(name, defn.into());
565            }
566            AliasDeclaration::Function(name, params) => {
567                let overloads = self.function_aliases.entry(name).or_default();
568                match overloads.binary_search_by_key(&params.len(), |(params, _)| params.len()) {
569                    Ok(i) => overloads[i] = (params, defn.into()),
570                    Err(i) => overloads.insert(i, (params, defn.into())),
571                }
572            }
573        }
574        Ok(())
575    }
576
577    /// Iterates symbol names in arbitrary order.
578    pub fn symbol_names(&self) -> impl Iterator<Item = &str> {
579        self.symbol_aliases.keys().map(|n| n.as_ref())
580    }
581
582    /// Iterates function names in arbitrary order.
583    pub fn function_names(&self) -> impl Iterator<Item = &str> {
584        self.function_aliases.keys().map(|n| n.as_ref())
585    }
586
587    /// Looks up symbol alias by name. Returns identifier and definition text.
588    pub fn get_symbol(&self, name: &str) -> Option<(AliasId<'_>, &V)> {
589        self.symbol_aliases
590            .get_key_value(name)
591            .map(|(name, defn)| (AliasId::Symbol(name), defn))
592    }
593
594    /// Looks up function alias by name and arity. Returns identifier, list of
595    /// parameter names, and definition text.
596    pub fn get_function(&self, name: &str, arity: usize) -> Option<(AliasId<'_>, &[String], &V)> {
597        let overloads = self.get_function_overloads(name)?;
598        overloads.find_by_arity(arity)
599    }
600
601    /// Looks up function aliases by name.
602    fn get_function_overloads(&self, name: &str) -> Option<AliasFunctionOverloads<'_, V>> {
603        let (name, overloads) = self.function_aliases.get_key_value(name)?;
604        Some(AliasFunctionOverloads { name, overloads })
605    }
606}
607
608#[derive(Clone, Debug)]
609struct AliasFunctionOverloads<'a, V> {
610    name: &'a String,
611    overloads: &'a Vec<(Vec<String>, V)>,
612}
613
614impl<'a, V> AliasFunctionOverloads<'a, V> {
615    fn arities(&self) -> impl DoubleEndedIterator<Item = usize> + ExactSizeIterator {
616        self.overloads.iter().map(|(params, _)| params.len())
617    }
618
619    fn min_arity(&self) -> usize {
620        self.arities().next().unwrap()
621    }
622
623    fn max_arity(&self) -> usize {
624        self.arities().next_back().unwrap()
625    }
626
627    fn find_by_arity(&self, arity: usize) -> Option<(AliasId<'a>, &'a [String], &'a V)> {
628        let index = self
629            .overloads
630            .binary_search_by_key(&arity, |(params, _)| params.len())
631            .ok()?;
632        let (params, defn) = &self.overloads[index];
633        // Exact parameter names aren't needed to identify a function, but they
634        // provide a better error indication. (e.g. "foo(x, y)" is easier to
635        // follow than "foo/2".)
636        Some((AliasId::Function(self.name, params), params, defn))
637    }
638}
639
/// Borrowed identifier of an alias expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AliasId<'a> {
    /// A symbol alias, by name.
    Symbol(&'a str),
    /// A function alias, by name and parameter names.
    Function(&'a str, &'a [String]),
    /// A function parameter, by name.
    Parameter(&'a str),
}

impl fmt::Display for AliasId<'_> {
    // Symbols and parameters render as their bare name; functions render as
    // `name(param1, param2, ...)`. Formatter flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            AliasId::Symbol(name) | AliasId::Parameter(name) => f.write_str(name),
            AliasId::Function(name, params) => write!(f, "{name}({})", params.join(", ")),
        }
    }
}
662
/// Parsed left-hand side of an alias rule.
#[derive(Debug, Clone)]
pub enum AliasDeclaration {
    /// A symbol alias, with its name.
    Symbol(String),
    /// A function alias, with its name and parameter names.
    Function(String, Vec<String>),
}
671
672// AliasDeclarationParser and AliasDefinitionParser can be merged into a single
673// trait, but it's unclear whether doing that would simplify the abstraction.
674
675/// Parser for symbol and function alias declaration.
676pub trait AliasDeclarationParser {
677    /// Parse error type.
678    type Error;
679
680    /// Parses symbol or function name and parameters.
681    fn parse_declaration(&self, source: &str) -> Result<AliasDeclaration, Self::Error>;
682}
683
684/// Parser for symbol and function alias definition.
685pub trait AliasDefinitionParser {
686    /// Expression item type.
687    type Output<'i>;
688    /// Parse error type.
689    type Error;
690
691    /// Parses alias body.
692    fn parse_definition<'i>(
693        &self,
694        source: &'i str,
695    ) -> Result<ExpressionNode<'i, Self::Output<'i>>, Self::Error>;
696}
697
698/// Expression item that supports alias substitution.
699pub trait AliasExpandableExpression<'i>: FoldableExpression<'i> {
700    /// Wraps identifier.
701    fn identifier(name: &'i str) -> Self;
702    /// Wraps function call.
703    fn function_call(function: Box<FunctionCallNode<'i, Self>>) -> Self;
704    /// Wraps substituted expression.
705    fn alias_expanded(id: AliasId<'i>, subst: Box<ExpressionNode<'i, Self>>) -> Self;
706}
707
708/// Error that may occur during alias substitution.
709pub trait AliasExpandError: Sized {
710    /// Unexpected number of arguments, or invalid combination of arguments.
711    fn invalid_arguments(err: InvalidArguments<'_>) -> Self;
712    /// Recursion detected during alias substitution.
713    fn recursive_expansion(id: AliasId<'_>, span: pest::Span<'_>) -> Self;
714    /// Attaches alias trace to the current error.
715    fn within_alias_expansion(self, id: AliasId<'_>, span: pest::Span<'_>) -> Self;
716}
717
718/// Expands aliases recursively in tree of `T`.
719#[derive(Debug)]
720struct AliasExpander<'i, 'a, T, P> {
721    /// Alias symbols and functions that are globally available.
722    aliases_map: &'i AliasesMap<P, String>,
723    /// Local variables set in the outermost scope.
724    locals: &'a HashMap<&'i str, ExpressionNode<'i, T>>,
725    /// Stack of aliases and local parameters currently expanding.
726    states: Vec<AliasExpandingState<'i, T>>,
727}
728
729#[derive(Debug)]
730struct AliasExpandingState<'i, T> {
731    id: AliasId<'i>,
732    locals: HashMap<&'i str, ExpressionNode<'i, T>>,
733}
734
735impl<'i, T, P, E> AliasExpander<'i, '_, T, P>
736where
737    T: AliasExpandableExpression<'i> + Clone,
738    P: AliasDefinitionParser<Output<'i> = T, Error = E>,
739    E: AliasExpandError,
740{
741    /// Local variables available to the current scope.
742    fn current_locals(&self) -> &HashMap<&'i str, ExpressionNode<'i, T>> {
743        self.states.last().map_or(self.locals, |s| &s.locals)
744    }
745
746    fn expand_defn(
747        &mut self,
748        id: AliasId<'i>,
749        defn: &'i str,
750        locals: HashMap<&'i str, ExpressionNode<'i, T>>,
751        span: pest::Span<'i>,
752    ) -> Result<T, E> {
753        // The stack should be short, so let's simply do linear search.
754        if self.states.iter().any(|s| s.id == id) {
755            return Err(E::recursive_expansion(id, span));
756        }
757        self.states.push(AliasExpandingState { id, locals });
758        // Parsed defn could be cached if needed.
759        let result = self
760            .aliases_map
761            .parser
762            .parse_definition(defn)
763            .and_then(|node| self.fold_expression(node))
764            .map(|node| T::alias_expanded(id, Box::new(node)))
765            .map_err(|e| e.within_alias_expansion(id, span));
766        self.states.pop();
767        result
768    }
769}
770
771impl<'i, T, P, E> ExpressionFolder<'i, T> for AliasExpander<'i, '_, T, P>
772where
773    T: AliasExpandableExpression<'i> + Clone,
774    P: AliasDefinitionParser<Output<'i> = T, Error = E>,
775    E: AliasExpandError,
776{
777    type Error = E;
778
779    fn fold_identifier(&mut self, name: &'i str, span: pest::Span<'i>) -> Result<T, Self::Error> {
780        if let Some(subst) = self.current_locals().get(name) {
781            let id = AliasId::Parameter(name);
782            Ok(T::alias_expanded(id, Box::new(subst.clone())))
783        } else if let Some((id, defn)) = self.aliases_map.get_symbol(name) {
784            let locals = HashMap::new(); // Don't spill out the current scope
785            self.expand_defn(id, defn, locals, span)
786        } else {
787            Ok(T::identifier(name))
788        }
789    }
790
791    fn fold_function_call(
792        &mut self,
793        function: Box<FunctionCallNode<'i, T>>,
794        span: pest::Span<'i>,
795    ) -> Result<T, Self::Error> {
796        // For better error indication, builtin functions are shadowed by name,
797        // not by (name, arity).
798        if let Some(overloads) = self.aliases_map.get_function_overloads(function.name) {
799            // TODO: add support for keyword arguments
800            function
801                .ensure_no_keyword_arguments()
802                .map_err(E::invalid_arguments)?;
803            let Some((id, params, defn)) = overloads.find_by_arity(function.arity()) else {
804                let min = overloads.min_arity();
805                let max = overloads.max_arity();
806                let err = if max - min + 1 == overloads.arities().len() {
807                    function.invalid_arguments_count(min, Some(max))
808                } else {
809                    function.invalid_arguments_count_with_arities(overloads.arities())
810                };
811                return Err(E::invalid_arguments(err));
812            };
813            // Resolve arguments in the current scope, and pass them in to the alias
814            // expansion scope.
815            let args = fold_expression_nodes(self, function.args)?;
816            let locals = params.iter().map(|s| s.as_str()).zip(args).collect();
817            self.expand_defn(id, defn, locals, span)
818        } else {
819            let function = Box::new(fold_function_call_args(self, *function)?);
820            Ok(T::function_call(function))
821        }
822    }
823}
824
825/// Expands aliases recursively.
826pub fn expand_aliases<'i, T, P>(
827    node: ExpressionNode<'i, T>,
828    aliases_map: &'i AliasesMap<P, String>,
829) -> Result<ExpressionNode<'i, T>, P::Error>
830where
831    T: AliasExpandableExpression<'i> + Clone,
832    P: AliasDefinitionParser<Output<'i> = T>,
833    P::Error: AliasExpandError,
834{
835    expand_aliases_with_locals(node, aliases_map, &HashMap::new())
836}
837
838/// Expands aliases recursively with the outermost local variables.
839///
840/// Local variables are similar to alias symbols, but are scoped. Alias symbols
841/// are globally accessible from alias expressions, but local variables aren't.
842pub fn expand_aliases_with_locals<'i, T, P>(
843    node: ExpressionNode<'i, T>,
844    aliases_map: &'i AliasesMap<P, String>,
845    locals: &HashMap<&'i str, ExpressionNode<'i, T>>,
846) -> Result<ExpressionNode<'i, T>, P::Error>
847where
848    T: AliasExpandableExpression<'i> + Clone,
849    P: AliasDefinitionParser<Output<'i> = T>,
850    P::Error: AliasExpandError,
851{
852    let mut expander = AliasExpander {
853        aliases_map,
854        locals,
855        states: Vec::new(),
856    };
857    expander.fold_expression(node)
858}
859
860/// Collects similar names from the `candidates` list.
861pub fn collect_similar<I>(name: &str, candidates: I) -> Vec<String>
862where
863    I: IntoIterator,
864    I::Item: AsRef<str>,
865{
866    candidates
867        .into_iter()
868        .filter(|cand| {
869            // The parameter is borrowed from clap f5540d26
870            strsim::jaro(name, cand.as_ref()) > 0.7
871        })
872        .map(|s| s.as_ref().to_owned())
873        .sorted_unstable()
874        .collect()
875}
876
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks how the `expect_*_arguments` helpers on `FunctionCallNode`
    /// accept or reject calls, and how they place positional and keyword
    /// arguments into required and optional slots.
    #[test]
    fn test_expect_arguments() {
        // Zero-length span over an empty string. Every node built below uses
        // this same span, so only the values differ between nodes.
        fn empty_span() -> pest::Span<'static> {
            pest::Span::new("", 0, 0).unwrap()
        }

        // Builds a call `name(args..., keyword_args...)` with dummy spans.
        fn function(
            name: &'static str,
            args: impl Into<Vec<ExpressionNode<'static, u32>>>,
            keyword_args: impl Into<Vec<KeywordArgument<'static, u32>>>,
        ) -> FunctionCallNode<'static, u32> {
            FunctionCallNode {
                name,
                name_span: empty_span(),
                args: args.into(),
                keyword_args: keyword_args.into(),
                args_span: empty_span(),
            }
        }

        // Expression node wrapping the plain value `v`.
        fn value(v: u32) -> ExpressionNode<'static, u32> {
            ExpressionNode::new(v, empty_span())
        }

        // Keyword argument `name = v`.
        fn keyword(name: &'static str, v: u32) -> KeywordArgument<'static, u32> {
            KeywordArgument {
                name,
                name_span: empty_span(),
                value: value(v),
            }
        }

        // `foo()`: every helper accepts a call with no arguments when zero
        // arguments are expected.
        let f = function("foo", [], []);
        assert!(f.expect_no_arguments().is_ok());
        assert!(f.expect_some_arguments::<0>().is_ok());
        assert!(f.expect_arguments::<0, 0>().is_ok());
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_ok());

        // `foo(0)`: one positional argument. It fills a required slot first,
        // otherwise an optional slot, otherwise the "rest" slice of
        // `expect_some_arguments`.
        let f = function("foo", [value(0)], []);
        assert!(f.expect_no_arguments().is_err());
        assert_eq!(
            f.expect_some_arguments::<0>().unwrap(),
            (&[], [value(0)].as_slice())
        );
        assert_eq!(
            f.expect_some_arguments::<1>().unwrap(),
            (&[value(0)], [].as_slice())
        );
        assert!(f.expect_arguments::<0, 0>().is_err());
        assert_eq!(
            f.expect_arguments::<0, 1>().unwrap(),
            (&[], [Some(&value(0))])
        );
        assert_eq!(f.expect_arguments::<1, 1>().unwrap(), (&[value(0)], [None]));
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_err());
        assert_eq!(
            f.expect_named_arguments::<0, 1>(&["a"]).unwrap(),
            ([], [Some(&value(0))])
        );
        assert_eq!(
            f.expect_named_arguments::<1, 0>(&["a"]).unwrap(),
            ([&value(0)], [])
        );

        // `foo(a=0)`: a keyword argument can't be bound without parameter
        // names, so every unnamed variant fails. `expect_named_arguments`
        // binds it only when "a" is among the names, and only when the
        // required slots can still be filled.
        let f = function("foo", [], [keyword("a", 0)]);
        assert!(f.expect_no_arguments().is_err());
        assert!(f.expect_some_arguments::<1>().is_err());
        assert!(f.expect_arguments::<0, 1>().is_err());
        assert!(f.expect_arguments::<1, 0>().is_err());
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_err());
        assert!(f.expect_named_arguments::<0, 1>(&[]).is_err());
        assert!(f.expect_named_arguments::<1, 0>(&[]).is_err());
        assert_eq!(
            f.expect_named_arguments::<1, 0>(&["a"]).unwrap(),
            ([&value(0)], [])
        );
        assert_eq!(
            f.expect_named_arguments::<1, 1>(&["a", "b"]).unwrap(),
            ([&value(0)], [None])
        );
        // Required parameter "b" is never supplied.
        assert!(f.expect_named_arguments::<1, 1>(&["b", "a"]).is_err());

        // `foo(0, a=1, b=2)`: keyword values land in the slot of their
        // parameter name. The results follow the order of the `names` array,
        // not the order of the keywords at the call site.
        let f = function("foo", [value(0)], [keyword("a", 1), keyword("b", 2)]);
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_err());
        // "a" would receive both the positional value and the `a=` keyword.
        assert!(f.expect_named_arguments::<1, 1>(&["a", "b"]).is_err());
        assert_eq!(
            f.expect_named_arguments::<1, 2>(&["c", "a", "b"]).unwrap(),
            ([&value(0)], [Some(&value(1)), Some(&value(2))])
        );
        assert_eq!(
            f.expect_named_arguments::<2, 1>(&["c", "b", "a"]).unwrap(),
            ([&value(0), &value(2)], [Some(&value(1))])
        );
        assert_eq!(
            f.expect_named_arguments::<0, 3>(&["c", "b", "a"]).unwrap(),
            ([], [Some(&value(0)), Some(&value(2)), Some(&value(1))])
        );

        // `foo(a=0, a=1)`: keyword "a" is given twice, and the call is
        // expected to be rejected.
        let f = function("foo", [], [keyword("a", 0), keyword("a", 1)]);
        assert!(f.expect_named_arguments::<1, 1>(&["", "a"]).is_err());
    }
}