jj_lib/
dsl_util.rs

1// Copyright 2020-2024 The Jujutsu Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Domain-specific language helpers.
16
17use std::ascii;
18use std::collections::HashMap;
19use std::fmt;
20use std::slice;
21
22use itertools::Itertools as _;
23use pest::iterators::Pair;
24use pest::iterators::Pairs;
25use pest::RuleType;
26
/// Collects diagnostic messages reported while parsing.
///
/// `T` is typically the language's parse error type, which carries a message
/// and a source span with `'static` lifetime.
#[derive(Debug)]
pub struct Diagnostics<T> {
    // Could later grow into [{ kind: Warning|Error, message: T }, ..].
    diagnostics: Vec<T>,
}

impl<T> Diagnostics<T> {
    /// Returns an empty collector.
    pub fn new() -> Self {
        Self::default()
    }

    /// Whether no message has been recorded yet.
    pub fn is_empty(&self) -> bool {
        self.diagnostics.is_empty()
    }

    /// How many messages have been recorded.
    pub fn len(&self) -> usize {
        self.diagnostics.len()
    }

    /// Borrowing iterator over the recorded messages, oldest first.
    pub fn iter(&self) -> slice::Iter<'_, T> {
        self.diagnostics.iter()
    }

    /// Records `diag` as a warning-level message.
    pub fn add_warning(&mut self, diag: T) {
        self.diagnostics.push(diag);
    }

    /// Takes over all messages of a collector with a different message type
    /// (e.g. fileset warnings raised inside a `file()` revset), converting
    /// each one with `f` and appending them in their original order.
    pub fn extend_with<U>(&mut self, diagnostics: Diagnostics<U>, f: impl FnMut(U) -> T) {
        let converted = diagnostics.diagnostics.into_iter().map(f);
        self.diagnostics.extend(converted);
    }
}

impl<T> Default for Diagnostics<T> {
    fn default() -> Self {
        Diagnostics {
            diagnostics: Vec::new(),
        }
    }
}

impl<'a, T> IntoIterator for &'a Diagnostics<T> {
    type Item = &'a T;
    type IntoIter = slice::Iter<'a, T>;

    fn into_iter(self) -> Self::IntoIter {
        self.diagnostics.iter()
    }
}
87
88/// AST node without type or name checking.
89#[derive(Clone, Debug, Eq, PartialEq)]
90pub struct ExpressionNode<'i, T> {
91    /// Expression item such as identifier, literal, function call, etc.
92    pub kind: T,
93    /// Span of the node.
94    pub span: pest::Span<'i>,
95}
96
97impl<'i, T> ExpressionNode<'i, T> {
98    /// Wraps the given expression and span.
99    pub fn new(kind: T, span: pest::Span<'i>) -> Self {
100        ExpressionNode { kind, span }
101    }
102}
103
104/// Function call in AST.
105#[derive(Clone, Debug, Eq, PartialEq)]
106pub struct FunctionCallNode<'i, T> {
107    /// Function name.
108    pub name: &'i str,
109    /// Span of the function name.
110    pub name_span: pest::Span<'i>,
111    /// List of positional arguments.
112    pub args: Vec<ExpressionNode<'i, T>>,
113    /// List of keyword arguments.
114    pub keyword_args: Vec<KeywordArgument<'i, T>>,
115    /// Span of the arguments list.
116    pub args_span: pest::Span<'i>,
117}
118
119/// Keyword argument pair in AST.
120#[derive(Clone, Debug, Eq, PartialEq)]
121pub struct KeywordArgument<'i, T> {
122    /// Parameter name.
123    pub name: &'i str,
124    /// Span of the parameter name.
125    pub name_span: pest::Span<'i>,
126    /// Value expression.
127    pub value: ExpressionNode<'i, T>,
128}
129
130impl<'i, T> FunctionCallNode<'i, T> {
131    /// Number of arguments assuming named arguments are all unique.
132    pub fn arity(&self) -> usize {
133        self.args.len() + self.keyword_args.len()
134    }
135
136    /// Ensures that no arguments passed.
137    pub fn expect_no_arguments(&self) -> Result<(), InvalidArguments<'i>> {
138        let ([], []) = self.expect_arguments()?;
139        Ok(())
140    }
141
142    /// Extracts exactly N required arguments.
143    pub fn expect_exact_arguments<const N: usize>(
144        &self,
145    ) -> Result<&[ExpressionNode<'i, T>; N], InvalidArguments<'i>> {
146        let (args, []) = self.expect_arguments()?;
147        Ok(args)
148    }
149
150    /// Extracts N required arguments and remainders.
151    #[expect(clippy::type_complexity)]
152    pub fn expect_some_arguments<const N: usize>(
153        &self,
154    ) -> Result<(&[ExpressionNode<'i, T>; N], &[ExpressionNode<'i, T>]), InvalidArguments<'i>> {
155        self.ensure_no_keyword_arguments()?;
156        if self.args.len() >= N {
157            let (required, rest) = self.args.split_at(N);
158            Ok((required.try_into().unwrap(), rest))
159        } else {
160            Err(self.invalid_arguments_count(N, None))
161        }
162    }
163
164    /// Extracts N required arguments and M optional arguments.
165    #[expect(clippy::type_complexity)]
166    pub fn expect_arguments<const N: usize, const M: usize>(
167        &self,
168    ) -> Result<
169        (
170            &[ExpressionNode<'i, T>; N],
171            [Option<&ExpressionNode<'i, T>>; M],
172        ),
173        InvalidArguments<'i>,
174    > {
175        self.ensure_no_keyword_arguments()?;
176        let count_range = N..=(N + M);
177        if count_range.contains(&self.args.len()) {
178            let (required, rest) = self.args.split_at(N);
179            let mut optional = rest.iter().map(Some).collect_vec();
180            optional.resize(M, None);
181            Ok((
182                required.try_into().unwrap(),
183                optional.try_into().ok().unwrap(),
184            ))
185        } else {
186            let (min, max) = count_range.into_inner();
187            Err(self.invalid_arguments_count(min, Some(max)))
188        }
189    }
190
191    /// Extracts N required arguments and M optional arguments. Some of them can
192    /// be specified as keyword arguments.
193    ///
194    /// `names` is a list of parameter names. Unnamed positional arguments
195    /// should be padded with `""`.
196    #[expect(clippy::type_complexity)]
197    pub fn expect_named_arguments<const N: usize, const M: usize>(
198        &self,
199        names: &[&str],
200    ) -> Result<
201        (
202            [&ExpressionNode<'i, T>; N],
203            [Option<&ExpressionNode<'i, T>>; M],
204        ),
205        InvalidArguments<'i>,
206    > {
207        if self.keyword_args.is_empty() {
208            let (required, optional) = self.expect_arguments::<N, M>()?;
209            Ok((required.each_ref(), optional))
210        } else {
211            let (required, optional) = self.expect_named_arguments_vec(names, N, N + M)?;
212            Ok((
213                required.try_into().ok().unwrap(),
214                optional.try_into().ok().unwrap(),
215            ))
216        }
217    }
218
219    #[expect(clippy::type_complexity)]
220    fn expect_named_arguments_vec(
221        &self,
222        names: &[&str],
223        min: usize,
224        max: usize,
225    ) -> Result<
226        (
227            Vec<&ExpressionNode<'i, T>>,
228            Vec<Option<&ExpressionNode<'i, T>>>,
229        ),
230        InvalidArguments<'i>,
231    > {
232        assert!(names.len() <= max);
233
234        if self.args.len() > max {
235            return Err(self.invalid_arguments_count(min, Some(max)));
236        }
237        let mut extracted = Vec::with_capacity(max);
238        extracted.extend(self.args.iter().map(Some));
239        extracted.resize(max, None);
240
241        for arg in &self.keyword_args {
242            let name = arg.name;
243            let span = arg.name_span.start_pos().span(&arg.value.span.end_pos());
244            let pos = names.iter().position(|&n| n == name).ok_or_else(|| {
245                self.invalid_arguments(format!(r#"Unexpected keyword argument "{name}""#), span)
246            })?;
247            if extracted[pos].is_some() {
248                return Err(self.invalid_arguments(
249                    format!(r#"Got multiple values for keyword "{name}""#),
250                    span,
251                ));
252            }
253            extracted[pos] = Some(&arg.value);
254        }
255
256        let optional = extracted.split_off(min);
257        let required = extracted.into_iter().flatten().collect_vec();
258        if required.len() != min {
259            return Err(self.invalid_arguments_count(min, Some(max)));
260        }
261        Ok((required, optional))
262    }
263
264    fn ensure_no_keyword_arguments(&self) -> Result<(), InvalidArguments<'i>> {
265        if let (Some(first), Some(last)) = (self.keyword_args.first(), self.keyword_args.last()) {
266            let span = first.name_span.start_pos().span(&last.value.span.end_pos());
267            Err(self.invalid_arguments("Unexpected keyword arguments".to_owned(), span))
268        } else {
269            Ok(())
270        }
271    }
272
273    fn invalid_arguments(&self, message: String, span: pest::Span<'i>) -> InvalidArguments<'i> {
274        InvalidArguments {
275            name: self.name,
276            message,
277            span,
278        }
279    }
280
281    fn invalid_arguments_count(&self, min: usize, max: Option<usize>) -> InvalidArguments<'i> {
282        let message = match (min, max) {
283            (min, Some(max)) if min == max => format!("Expected {min} arguments"),
284            (min, Some(max)) => format!("Expected {min} to {max} arguments"),
285            (min, None) => format!("Expected at least {min} arguments"),
286        };
287        self.invalid_arguments(message, self.args_span)
288    }
289
290    fn invalid_arguments_count_with_arities(
291        &self,
292        arities: impl IntoIterator<Item = usize>,
293    ) -> InvalidArguments<'i> {
294        let message = format!("Expected {} arguments", arities.into_iter().join(", "));
295        self.invalid_arguments(message, self.args_span)
296    }
297}
298
299/// Unexpected number of arguments, or invalid combination of arguments.
300///
301/// This error is supposed to be converted to language-specific parse error
302/// type, where lifetime `'i` will be eliminated.
303#[derive(Clone, Debug)]
304pub struct InvalidArguments<'i> {
305    /// Function name.
306    pub name: &'i str,
307    /// Error message.
308    pub message: String,
309    /// Span of the bad arguments.
310    pub span: pest::Span<'i>,
311}
312
313/// Expression item that can be transformed recursively by using `folder: F`.
314pub trait FoldableExpression<'i>: Sized {
315    /// Transforms `self` by applying the `folder` to inner items.
316    fn fold<F>(self, folder: &mut F, span: pest::Span<'i>) -> Result<Self, F::Error>
317    where
318        F: ExpressionFolder<'i, Self> + ?Sized;
319}
320
321/// Visitor-like interface to transform AST nodes recursively.
322pub trait ExpressionFolder<'i, T: FoldableExpression<'i>> {
323    /// Transform error.
324    type Error;
325
326    /// Transforms the expression `node`. By default, inner items are
327    /// transformed recursively.
328    fn fold_expression(
329        &mut self,
330        node: ExpressionNode<'i, T>,
331    ) -> Result<ExpressionNode<'i, T>, Self::Error> {
332        let ExpressionNode { kind, span } = node;
333        let kind = kind.fold(self, span)?;
334        Ok(ExpressionNode { kind, span })
335    }
336
337    /// Transforms identifier.
338    fn fold_identifier(&mut self, name: &'i str, span: pest::Span<'i>) -> Result<T, Self::Error>;
339
340    /// Transforms function call.
341    fn fold_function_call(
342        &mut self,
343        function: Box<FunctionCallNode<'i, T>>,
344        span: pest::Span<'i>,
345    ) -> Result<T, Self::Error>;
346}
347
348/// Transforms list of `nodes` by using `folder`.
349pub fn fold_expression_nodes<'i, F, T>(
350    folder: &mut F,
351    nodes: Vec<ExpressionNode<'i, T>>,
352) -> Result<Vec<ExpressionNode<'i, T>>, F::Error>
353where
354    F: ExpressionFolder<'i, T> + ?Sized,
355    T: FoldableExpression<'i>,
356{
357    nodes
358        .into_iter()
359        .map(|node| folder.fold_expression(node))
360        .try_collect()
361}
362
363/// Transforms function call arguments by using `folder`.
364pub fn fold_function_call_args<'i, F, T>(
365    folder: &mut F,
366    function: FunctionCallNode<'i, T>,
367) -> Result<FunctionCallNode<'i, T>, F::Error>
368where
369    F: ExpressionFolder<'i, T> + ?Sized,
370    T: FoldableExpression<'i>,
371{
372    Ok(FunctionCallNode {
373        name: function.name,
374        name_span: function.name_span,
375        args: fold_expression_nodes(folder, function.args)?,
376        keyword_args: function
377            .keyword_args
378            .into_iter()
379            .map(|arg| {
380                Ok(KeywordArgument {
381                    name: arg.name,
382                    name_span: arg.name_span,
383                    value: folder.fold_expression(arg.value)?,
384                })
385            })
386            .try_collect()?,
387        args_span: function.args_span,
388    })
389}
390
/// Helper for decoding string literals from parsed pairs.
#[derive(Debug)]
pub struct StringLiteralParser<R> {
    /// Rule matching a run of literal (unescaped) characters.
    pub content_rule: R,
    /// Rule matching an escape sequence, including its leading backslash.
    pub escape_rule: R,
}
399
400impl<R: RuleType> StringLiteralParser<R> {
401    /// Parses the given string literal `pairs` into string.
402    pub fn parse(&self, pairs: Pairs<R>) -> String {
403        let mut result = String::new();
404        for part in pairs {
405            if part.as_rule() == self.content_rule {
406                result.push_str(part.as_str());
407            } else if part.as_rule() == self.escape_rule {
408                match &part.as_str()[1..] {
409                    "\"" => result.push('"'),
410                    "\\" => result.push('\\'),
411                    "t" => result.push('\t'),
412                    "r" => result.push('\r'),
413                    "n" => result.push('\n'),
414                    "0" => result.push('\0'),
415                    "e" => result.push('\x1b'),
416                    hex if hex.starts_with('x') => {
417                        result.push(char::from(
418                            u8::from_str_radix(&hex[1..], 16).expect("hex characters"),
419                        ));
420                    }
421                    char => panic!("invalid escape: \\{char:?}"),
422                }
423            } else {
424                panic!("unexpected part of string: {part:?}");
425            }
426        }
427        result
428    }
429}
430
/// Escapes special characters in `unescaped` so it can be embedded in a
/// string literal.
///
/// Quotes, backslashes, tab, CR, LF and NUL get their short backslash
/// escapes. Every other ASCII control character becomes a `\xNN` hex escape.
/// All remaining characters, including non-ASCII ones, are copied unchanged.
pub fn escape_string(unescaped: &str) -> String {
    unescaped
        .chars()
        .fold(String::with_capacity(unescaped.len()), |mut out, ch| {
            let short_escape = match ch {
                '"' => Some(r#"\""#),
                '\\' => Some(r#"\\"#),
                '\t' => Some(r#"\t"#),
                '\r' => Some(r#"\r"#),
                '\n' => Some(r#"\n"#),
                '\0' => Some(r#"\0"#),
                _ => None,
            };
            if let Some(seq) = short_escape {
                out.push_str(seq);
            } else if ch.is_ascii_control() {
                out.extend(ascii::escape_default(ch as u8).map(char::from));
            } else {
                out.push(ch);
            }
            out
        })
}
452
/// Helper for building [`FunctionCallNode`]s from parsed pairs.
#[derive(Debug)]
pub struct FunctionCallParser<R> {
    /// Rule for the called function's name.
    pub function_name_rule: R,
    /// Rule for the argument list (positional and keyword arguments).
    pub function_arguments_rule: R,
    /// Rule for a keyword argument, i.e. a parameter name and value pair.
    pub keyword_argument_rule: R,
    /// Rule for the parameter name of a keyword argument.
    pub argument_name_rule: R,
    /// Rule for an argument value expression.
    pub argument_value_rule: R,
}
467
468impl<R: RuleType> FunctionCallParser<R> {
469    /// Parses the given `pair` as function call.
470    pub fn parse<'i, T, E: From<InvalidArguments<'i>>>(
471        &self,
472        pair: Pair<'i, R>,
473        // parse_name can be defined for any Pair<'_, R>, but parse_value should
474        // be allowed to construct T by capturing Pair<'i, R>.
475        parse_name: impl Fn(Pair<'i, R>) -> Result<&'i str, E>,
476        parse_value: impl Fn(Pair<'i, R>) -> Result<ExpressionNode<'i, T>, E>,
477    ) -> Result<FunctionCallNode<'i, T>, E> {
478        let [name_pair, args_pair] = pair.into_inner().collect_array().unwrap();
479        assert_eq!(name_pair.as_rule(), self.function_name_rule);
480        assert_eq!(args_pair.as_rule(), self.function_arguments_rule);
481        let name_span = name_pair.as_span();
482        let args_span = args_pair.as_span();
483        let function_name = parse_name(name_pair)?;
484        let mut args = Vec::new();
485        let mut keyword_args = Vec::new();
486        for pair in args_pair.into_inner() {
487            let span = pair.as_span();
488            if pair.as_rule() == self.argument_value_rule {
489                if !keyword_args.is_empty() {
490                    return Err(InvalidArguments {
491                        name: function_name,
492                        message: "Positional argument follows keyword argument".to_owned(),
493                        span,
494                    }
495                    .into());
496                }
497                args.push(parse_value(pair)?);
498            } else if pair.as_rule() == self.keyword_argument_rule {
499                let [name_pair, value_pair] = pair.into_inner().collect_array().unwrap();
500                assert_eq!(name_pair.as_rule(), self.argument_name_rule);
501                assert_eq!(value_pair.as_rule(), self.argument_value_rule);
502                let name_span = name_pair.as_span();
503                let arg = KeywordArgument {
504                    name: parse_name(name_pair)?,
505                    name_span,
506                    value: parse_value(value_pair)?,
507                };
508                keyword_args.push(arg);
509            } else {
510                panic!("unexpected argument rule {pair:?}");
511            }
512        }
513        Ok(FunctionCallNode {
514            name: function_name,
515            name_span,
516            args,
517            keyword_args,
518            args_span,
519        })
520    }
521}
522
523/// Map of symbol and function aliases.
524#[derive(Clone, Debug, Default)]
525pub struct AliasesMap<P, V> {
526    symbol_aliases: HashMap<String, V>,
527    // name: [(params, defn)] (sorted by arity)
528    function_aliases: HashMap<String, Vec<(Vec<String>, V)>>,
529    // Parser type P helps prevent misuse of AliasesMap of different language.
530    parser: P,
531}
532
533impl<P, V> AliasesMap<P, V> {
534    /// Creates an empty aliases map with default-constructed parser.
535    pub fn new() -> Self
536    where
537        P: Default,
538    {
539        Self {
540            symbol_aliases: Default::default(),
541            function_aliases: Default::default(),
542            parser: Default::default(),
543        }
544    }
545
546    /// Adds new substitution rule `decl = defn`.
547    ///
548    /// Returns error if `decl` is invalid. The `defn` part isn't checked. A bad
549    /// `defn` will be reported when the alias is substituted.
550    pub fn insert(&mut self, decl: impl AsRef<str>, defn: impl Into<V>) -> Result<(), P::Error>
551    where
552        P: AliasDeclarationParser,
553    {
554        match self.parser.parse_declaration(decl.as_ref())? {
555            AliasDeclaration::Symbol(name) => {
556                self.symbol_aliases.insert(name, defn.into());
557            }
558            AliasDeclaration::Function(name, params) => {
559                let overloads = self.function_aliases.entry(name).or_default();
560                match overloads.binary_search_by_key(&params.len(), |(params, _)| params.len()) {
561                    Ok(i) => overloads[i] = (params, defn.into()),
562                    Err(i) => overloads.insert(i, (params, defn.into())),
563                }
564            }
565        }
566        Ok(())
567    }
568
569    /// Iterates symbol names in arbitrary order.
570    pub fn symbol_names(&self) -> impl Iterator<Item = &str> {
571        self.symbol_aliases.keys().map(|n| n.as_ref())
572    }
573
574    /// Iterates function names in arbitrary order.
575    pub fn function_names(&self) -> impl Iterator<Item = &str> {
576        self.function_aliases.keys().map(|n| n.as_ref())
577    }
578
579    /// Looks up symbol alias by name. Returns identifier and definition text.
580    pub fn get_symbol(&self, name: &str) -> Option<(AliasId<'_>, &V)> {
581        self.symbol_aliases
582            .get_key_value(name)
583            .map(|(name, defn)| (AliasId::Symbol(name), defn))
584    }
585
586    /// Looks up function alias by name and arity. Returns identifier, list of
587    /// parameter names, and definition text.
588    pub fn get_function(&self, name: &str, arity: usize) -> Option<(AliasId<'_>, &[String], &V)> {
589        let overloads = self.get_function_overloads(name)?;
590        overloads.find_by_arity(arity)
591    }
592
593    /// Looks up function aliases by name.
594    fn get_function_overloads(&self, name: &str) -> Option<AliasFunctionOverloads<'_, V>> {
595        let (name, overloads) = self.function_aliases.get_key_value(name)?;
596        Some(AliasFunctionOverloads { name, overloads })
597    }
598}
599
600#[derive(Clone, Debug)]
601struct AliasFunctionOverloads<'a, V> {
602    name: &'a String,
603    overloads: &'a Vec<(Vec<String>, V)>,
604}
605
606impl<'a, V> AliasFunctionOverloads<'a, V> {
607    // TODO: Perhaps, V doesn't have to be captured, but "currently, all type
608    // parameters are required to be mentioned in the precise captures list" as
609    // of rustc 1.85.0.
610    fn arities(&self) -> impl DoubleEndedIterator<Item = usize> + ExactSizeIterator + use<'a, V> {
611        self.overloads.iter().map(|(params, _)| params.len())
612    }
613
614    fn min_arity(&self) -> usize {
615        self.arities().next().unwrap()
616    }
617
618    fn max_arity(&self) -> usize {
619        self.arities().next_back().unwrap()
620    }
621
622    fn find_by_arity(&self, arity: usize) -> Option<(AliasId<'a>, &'a [String], &'a V)> {
623        let index = self
624            .overloads
625            .binary_search_by_key(&arity, |(params, _)| params.len())
626            .ok()?;
627        let (params, defn) = &self.overloads[index];
628        // Exact parameter names aren't needed to identify a function, but they
629        // provide a better error indication. (e.g. "foo(x, y)" is easier to
630        // follow than "foo/2".)
631        Some((AliasId::Function(self.name, params), params, defn))
632    }
633}
634
/// Borrowed identifier of an alias expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AliasId<'a> {
    /// Name of a symbol alias.
    Symbol(&'a str),
    /// Name and parameter names of a function alias.
    Function(&'a str, &'a [String]),
    /// Name of a function alias parameter.
    Parameter(&'a str),
}

impl fmt::Display for AliasId<'_> {
    /// Symbols and parameters print as their bare name; functions print as
    /// `name(param1, param2, ...)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            AliasId::Symbol(name) | AliasId::Parameter(name) => f.write_str(name),
            AliasId::Function(name, params) => {
                let joined = params.join(", ");
                write!(f, "{name}({joined})")
            }
        }
    }
}
657
/// Parsed left-hand (declaration) side of an alias rule.
#[derive(Debug, Clone)]
pub enum AliasDeclaration {
    /// Symbol alias, identified by its name.
    Symbol(String),
    /// Function alias, identified by its name and parameter names.
    Function(String, Vec<String>),
}
666
667// AliasDeclarationParser and AliasDefinitionParser can be merged into a single
668// trait, but it's unclear whether doing that would simplify the abstraction.
669
670/// Parser for symbol and function alias declaration.
671pub trait AliasDeclarationParser {
672    /// Parse error type.
673    type Error;
674
675    /// Parses symbol or function name and parameters.
676    fn parse_declaration(&self, source: &str) -> Result<AliasDeclaration, Self::Error>;
677}
678
679/// Parser for symbol and function alias definition.
680pub trait AliasDefinitionParser {
681    /// Expression item type.
682    type Output<'i>;
683    /// Parse error type.
684    type Error;
685
686    /// Parses alias body.
687    fn parse_definition<'i>(
688        &self,
689        source: &'i str,
690    ) -> Result<ExpressionNode<'i, Self::Output<'i>>, Self::Error>;
691}
692
693/// Expression item that supports alias substitution.
694pub trait AliasExpandableExpression<'i>: FoldableExpression<'i> {
695    /// Wraps identifier.
696    fn identifier(name: &'i str) -> Self;
697    /// Wraps function call.
698    fn function_call(function: Box<FunctionCallNode<'i, Self>>) -> Self;
699    /// Wraps substituted expression.
700    fn alias_expanded(id: AliasId<'i>, subst: Box<ExpressionNode<'i, Self>>) -> Self;
701}
702
703/// Error that may occur during alias substitution.
704pub trait AliasExpandError: Sized {
705    /// Unexpected number of arguments, or invalid combination of arguments.
706    fn invalid_arguments(err: InvalidArguments<'_>) -> Self;
707    /// Recursion detected during alias substitution.
708    fn recursive_expansion(id: AliasId<'_>, span: pest::Span<'_>) -> Self;
709    /// Attaches alias trace to the current error.
710    fn within_alias_expansion(self, id: AliasId<'_>, span: pest::Span<'_>) -> Self;
711}
712
713/// Expands aliases recursively in tree of `T`.
714#[derive(Debug)]
715struct AliasExpander<'i, 'a, T, P> {
716    /// Alias symbols and functions that are globally available.
717    aliases_map: &'i AliasesMap<P, String>,
718    /// Local variables set in the outermost scope.
719    locals: &'a HashMap<&'i str, ExpressionNode<'i, T>>,
720    /// Stack of aliases and local parameters currently expanding.
721    states: Vec<AliasExpandingState<'i, T>>,
722}
723
724#[derive(Debug)]
725struct AliasExpandingState<'i, T> {
726    id: AliasId<'i>,
727    locals: HashMap<&'i str, ExpressionNode<'i, T>>,
728}
729
730impl<'i, T, P, E> AliasExpander<'i, '_, T, P>
731where
732    T: AliasExpandableExpression<'i> + Clone,
733    P: AliasDefinitionParser<Output<'i> = T, Error = E>,
734    E: AliasExpandError,
735{
736    /// Local variables available to the current scope.
737    fn current_locals(&self) -> &HashMap<&'i str, ExpressionNode<'i, T>> {
738        self.states.last().map_or(self.locals, |s| &s.locals)
739    }
740
741    fn expand_defn(
742        &mut self,
743        id: AliasId<'i>,
744        defn: &'i str,
745        locals: HashMap<&'i str, ExpressionNode<'i, T>>,
746        span: pest::Span<'i>,
747    ) -> Result<T, E> {
748        // The stack should be short, so let's simply do linear search.
749        if self.states.iter().any(|s| s.id == id) {
750            return Err(E::recursive_expansion(id, span));
751        }
752        self.states.push(AliasExpandingState { id, locals });
753        // Parsed defn could be cached if needed.
754        let result = self
755            .aliases_map
756            .parser
757            .parse_definition(defn)
758            .and_then(|node| self.fold_expression(node))
759            .map(|node| T::alias_expanded(id, Box::new(node)))
760            .map_err(|e| e.within_alias_expansion(id, span));
761        self.states.pop();
762        result
763    }
764}
765
766impl<'i, T, P, E> ExpressionFolder<'i, T> for AliasExpander<'i, '_, T, P>
767where
768    T: AliasExpandableExpression<'i> + Clone,
769    P: AliasDefinitionParser<Output<'i> = T, Error = E>,
770    E: AliasExpandError,
771{
772    type Error = E;
773
774    fn fold_identifier(&mut self, name: &'i str, span: pest::Span<'i>) -> Result<T, Self::Error> {
775        if let Some(subst) = self.current_locals().get(name) {
776            let id = AliasId::Parameter(name);
777            Ok(T::alias_expanded(id, Box::new(subst.clone())))
778        } else if let Some((id, defn)) = self.aliases_map.get_symbol(name) {
779            let locals = HashMap::new(); // Don't spill out the current scope
780            self.expand_defn(id, defn, locals, span)
781        } else {
782            Ok(T::identifier(name))
783        }
784    }
785
786    fn fold_function_call(
787        &mut self,
788        function: Box<FunctionCallNode<'i, T>>,
789        span: pest::Span<'i>,
790    ) -> Result<T, Self::Error> {
791        // For better error indication, builtin functions are shadowed by name,
792        // not by (name, arity).
793        if let Some(overloads) = self.aliases_map.get_function_overloads(function.name) {
794            // TODO: add support for keyword arguments
795            function
796                .ensure_no_keyword_arguments()
797                .map_err(E::invalid_arguments)?;
798            let Some((id, params, defn)) = overloads.find_by_arity(function.arity()) else {
799                let min = overloads.min_arity();
800                let max = overloads.max_arity();
801                let err = if max - min + 1 == overloads.arities().len() {
802                    function.invalid_arguments_count(min, Some(max))
803                } else {
804                    function.invalid_arguments_count_with_arities(overloads.arities())
805                };
806                return Err(E::invalid_arguments(err));
807            };
808            // Resolve arguments in the current scope, and pass them in to the alias
809            // expansion scope.
810            let args = fold_expression_nodes(self, function.args)?;
811            let locals = params.iter().map(|s| s.as_str()).zip(args).collect();
812            self.expand_defn(id, defn, locals, span)
813        } else {
814            let function = Box::new(fold_function_call_args(self, *function)?);
815            Ok(T::function_call(function))
816        }
817    }
818}
819
820/// Expands aliases recursively.
821pub fn expand_aliases<'i, T, P>(
822    node: ExpressionNode<'i, T>,
823    aliases_map: &'i AliasesMap<P, String>,
824) -> Result<ExpressionNode<'i, T>, P::Error>
825where
826    T: AliasExpandableExpression<'i> + Clone,
827    P: AliasDefinitionParser<Output<'i> = T>,
828    P::Error: AliasExpandError,
829{
830    expand_aliases_with_locals(node, aliases_map, &HashMap::new())
831}
832
833/// Expands aliases recursively with the outermost local variables.
834///
835/// Local variables are similar to alias symbols, but are scoped. Alias symbols
836/// are globally accessible from alias expressions, but local variables aren't.
837pub fn expand_aliases_with_locals<'i, T, P>(
838    node: ExpressionNode<'i, T>,
839    aliases_map: &'i AliasesMap<P, String>,
840    locals: &HashMap<&'i str, ExpressionNode<'i, T>>,
841) -> Result<ExpressionNode<'i, T>, P::Error>
842where
843    T: AliasExpandableExpression<'i> + Clone,
844    P: AliasDefinitionParser<Output<'i> = T>,
845    P::Error: AliasExpandError,
846{
847    let mut expander = AliasExpander {
848        aliases_map,
849        locals,
850        states: Vec::new(),
851    };
852    expander.fold_expression(node)
853}
854
855/// Collects similar names from the `candidates` list.
856pub fn collect_similar<I>(name: &str, candidates: I) -> Vec<String>
857where
858    I: IntoIterator,
859    I::Item: AsRef<str>,
860{
861    candidates
862        .into_iter()
863        .filter(|cand| {
864            // The parameter is borrowed from clap f5540d26
865            strsim::jaro(name, cand.as_ref()) > 0.7
866        })
867        .map(|s| s.as_ref().to_owned())
868        .sorted_unstable()
869        .collect()
870}
871
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks how `FunctionCallNode`'s `expect_no_arguments`,
    /// `expect_some_arguments`, `expect_arguments`, and
    /// `expect_named_arguments` accept or reject positional and keyword
    /// arguments.
    #[test]
    fn test_expect_arguments() {
        // Zero-width span over an empty string. Spans are not inspected by
        // these checks.
        fn empty_span() -> pest::Span<'static> {
            pest::Span::new("", 0, 0).unwrap()
        }

        // Builds a call node `name(args..., keyword_args...)` with empty spans.
        fn function(
            name: &'static str,
            args: impl Into<Vec<ExpressionNode<'static, u32>>>,
            keyword_args: impl Into<Vec<KeywordArgument<'static, u32>>>,
        ) -> FunctionCallNode<'static, u32> {
            FunctionCallNode {
                name,
                name_span: empty_span(),
                args: args.into(),
                keyword_args: keyword_args.into(),
                args_span: empty_span(),
            }
        }

        // Expression node wrapping the plain value `v`.
        fn value(v: u32) -> ExpressionNode<'static, u32> {
            ExpressionNode::new(v, empty_span())
        }

        // Keyword argument `name=v`.
        fn keyword(name: &'static str, v: u32) -> KeywordArgument<'static, u32> {
            KeywordArgument {
                name,
                name_span: empty_span(),
                value: value(v),
            }
        }

        // foo(): every expectation with no required and no optional slots
        // succeeds.
        let f = function("foo", [], []);
        assert!(f.expect_no_arguments().is_ok());
        assert!(f.expect_some_arguments::<0>().is_ok());
        assert!(f.expect_arguments::<0, 0>().is_ok());
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_ok());

        // foo(0): a single positional argument.
        let f = function("foo", [value(0)], []);
        assert!(f.expect_no_arguments().is_err());
        // `expect_some_arguments::<N>` yields N required arguments plus the
        // remaining arguments as a slice.
        assert_eq!(
            f.expect_some_arguments::<0>().unwrap(),
            (&[], [value(0)].as_slice())
        );
        assert_eq!(
            f.expect_some_arguments::<1>().unwrap(),
            (&[value(0)], [].as_slice())
        );
        // One argument is too many for zero required + zero optional slots.
        assert!(f.expect_arguments::<0, 0>().is_err());
        // Optional slots are `Some` when filled and `None` when not.
        assert_eq!(
            f.expect_arguments::<0, 1>().unwrap(),
            (&[], [Some(&value(0))])
        );
        assert_eq!(f.expect_arguments::<1, 1>().unwrap(), (&[value(0)], [None]));
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_err());
        // A positional argument binds to the first named parameter, whether
        // that parameter is optional or required.
        assert_eq!(
            f.expect_named_arguments::<0, 1>(&["a"]).unwrap(),
            ([], [Some(&value(0))])
        );
        assert_eq!(
            f.expect_named_arguments::<1, 0>(&["a"]).unwrap(),
            ([&value(0)], [])
        );

        // foo(a=0): a single keyword argument.
        let f = function("foo", [], [keyword("a", 0)]);
        // Expectations without parameter names reject the keyword argument.
        assert!(f.expect_no_arguments().is_err());
        assert!(f.expect_some_arguments::<1>().is_err());
        assert!(f.expect_arguments::<0, 1>().is_err());
        assert!(f.expect_arguments::<1, 0>().is_err());
        // Presumably rejected because an empty name list gives `a` no
        // parameter to bind to.
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_err());
        assert!(f.expect_named_arguments::<0, 1>(&[]).is_err());
        assert!(f.expect_named_arguments::<1, 0>(&[]).is_err());
        // The keyword argument fills the parameter with the matching name.
        // Unfilled optional parameters stay `None`.
        assert_eq!(
            f.expect_named_arguments::<1, 0>(&["a"]).unwrap(),
            ([&value(0)], [])
        );
        assert_eq!(
            f.expect_named_arguments::<1, 1>(&["a", "b"]).unwrap(),
            ([&value(0)], [None])
        );
        // Required parameter `b` is not supplied.
        assert!(f.expect_named_arguments::<1, 1>(&["b", "a"]).is_err());

        // foo(0, a=1, b=2): a mix of positional and keyword arguments.
        let f = function("foo", [value(0)], [keyword("a", 1), keyword("b", 2)]);
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_err());
        // Presumably rejected because positional `0` already binds `a`, so
        // keyword `a` collides with it.
        assert!(f.expect_named_arguments::<1, 1>(&["a", "b"]).is_err());
        // Results follow the declared parameter order, not the call order.
        assert_eq!(
            f.expect_named_arguments::<1, 2>(&["c", "a", "b"]).unwrap(),
            ([&value(0)], [Some(&value(1)), Some(&value(2))])
        );
        assert_eq!(
            f.expect_named_arguments::<2, 1>(&["c", "b", "a"]).unwrap(),
            ([&value(0), &value(2)], [Some(&value(1))])
        );
        assert_eq!(
            f.expect_named_arguments::<0, 3>(&["c", "b", "a"]).unwrap(),
            ([], [Some(&value(0)), Some(&value(2)), Some(&value(1))])
        );

        // foo(a=0, a=1): the keyword `a` is given twice, and the call must be
        // rejected. NOTE(review): the required first parameter (empty name)
        // also gets no positional argument here; confirm which condition this
        // case is meant to cover.
        let f = function("foo", [], [keyword("a", 0), keyword("a", 1)]);
        assert!(f.expect_named_arguments::<1, 1>(&["", "a"]).is_err());
    }
}