jj_lib/
dsl_util.rs

// Copyright 2020-2024 The Jujutsu Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
//! Domain-specific language helpers: diagnostics, function argument
//! extraction, string literal handling, and alias expansion.
16
17use std::array;
18use std::ascii;
19use std::collections::HashMap;
20use std::fmt;
21use std::slice;
22
23use itertools::Itertools as _;
24use pest::iterators::Pair;
25use pest::iterators::Pairs;
26use pest::RuleType;
27
/// Collects diagnostic messages produced while parsing.
///
/// `T` is typically the language's parse error type, which carries a message
/// and a source span of `'static` lifetime.
#[derive(Debug)]
pub struct Diagnostics<T> {
    // Could later grow into [{ kind: Warning|Error, message: T }, ..].
    diagnostics: Vec<T>,
}
37
38impl<T> Diagnostics<T> {
39    /// Creates new empty diagnostics collector.
40    pub fn new() -> Self {
41        Diagnostics {
42            diagnostics: Vec::new(),
43        }
44    }
45
46    /// Returns `true` if there are no diagnostic messages.
47    pub fn is_empty(&self) -> bool {
48        self.diagnostics.is_empty()
49    }
50
51    /// Returns the number of diagnostic messages.
52    pub fn len(&self) -> usize {
53        self.diagnostics.len()
54    }
55
56    /// Returns iterator over diagnostic messages.
57    pub fn iter(&self) -> slice::Iter<'_, T> {
58        self.diagnostics.iter()
59    }
60
61    /// Adds a diagnostic message of warning level.
62    pub fn add_warning(&mut self, diag: T) {
63        self.diagnostics.push(diag);
64    }
65
66    /// Moves diagnostic messages of different type (such as fileset warnings
67    /// emitted within `file()` revset.)
68    pub fn extend_with<U>(&mut self, diagnostics: Diagnostics<U>, mut f: impl FnMut(U) -> T) {
69        self.diagnostics
70            .extend(diagnostics.diagnostics.into_iter().map(&mut f));
71    }
72}
73
74impl<T> Default for Diagnostics<T> {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl<'a, T> IntoIterator for &'a Diagnostics<T> {
81    type Item = &'a T;
82    type IntoIter = slice::Iter<'a, T>;
83
84    fn into_iter(self) -> Self::IntoIter {
85        self.iter()
86    }
87}
88
89/// AST node without type or name checking.
90#[derive(Clone, Debug, Eq, PartialEq)]
91pub struct ExpressionNode<'i, T> {
92    /// Expression item such as identifier, literal, function call, etc.
93    pub kind: T,
94    /// Span of the node.
95    pub span: pest::Span<'i>,
96}
97
98impl<'i, T> ExpressionNode<'i, T> {
99    /// Wraps the given expression and span.
100    pub fn new(kind: T, span: pest::Span<'i>) -> Self {
101        ExpressionNode { kind, span }
102    }
103}
104
105/// Function call in AST.
106#[derive(Clone, Debug, Eq, PartialEq)]
107pub struct FunctionCallNode<'i, T> {
108    /// Function name.
109    pub name: &'i str,
110    /// Span of the function name.
111    pub name_span: pest::Span<'i>,
112    /// List of positional arguments.
113    pub args: Vec<ExpressionNode<'i, T>>,
114    /// List of keyword arguments.
115    pub keyword_args: Vec<KeywordArgument<'i, T>>,
116    /// Span of the arguments list.
117    pub args_span: pest::Span<'i>,
118}
119
120/// Keyword argument pair in AST.
121#[derive(Clone, Debug, Eq, PartialEq)]
122pub struct KeywordArgument<'i, T> {
123    /// Parameter name.
124    pub name: &'i str,
125    /// Span of the parameter name.
126    pub name_span: pest::Span<'i>,
127    /// Value expression.
128    pub value: ExpressionNode<'i, T>,
129}
130
131impl<'i, T> FunctionCallNode<'i, T> {
132    /// Number of arguments assuming named arguments are all unique.
133    pub fn arity(&self) -> usize {
134        self.args.len() + self.keyword_args.len()
135    }
136
137    /// Ensures that no arguments passed.
138    pub fn expect_no_arguments(&self) -> Result<(), InvalidArguments<'i>> {
139        let ([], []) = self.expect_arguments()?;
140        Ok(())
141    }
142
143    /// Extracts exactly N required arguments.
144    pub fn expect_exact_arguments<const N: usize>(
145        &self,
146    ) -> Result<&[ExpressionNode<'i, T>; N], InvalidArguments<'i>> {
147        let (args, []) = self.expect_arguments()?;
148        Ok(args)
149    }
150
151    /// Extracts N required arguments and remainders.
152    #[allow(clippy::type_complexity)]
153    pub fn expect_some_arguments<const N: usize>(
154        &self,
155    ) -> Result<(&[ExpressionNode<'i, T>; N], &[ExpressionNode<'i, T>]), InvalidArguments<'i>> {
156        self.ensure_no_keyword_arguments()?;
157        if self.args.len() >= N {
158            let (required, rest) = self.args.split_at(N);
159            Ok((required.try_into().unwrap(), rest))
160        } else {
161            Err(self.invalid_arguments_count(N, None))
162        }
163    }
164
165    /// Extracts N required arguments and M optional arguments.
166    #[allow(clippy::type_complexity)]
167    pub fn expect_arguments<const N: usize, const M: usize>(
168        &self,
169    ) -> Result<
170        (
171            &[ExpressionNode<'i, T>; N],
172            [Option<&ExpressionNode<'i, T>>; M],
173        ),
174        InvalidArguments<'i>,
175    > {
176        self.ensure_no_keyword_arguments()?;
177        let count_range = N..=(N + M);
178        if count_range.contains(&self.args.len()) {
179            let (required, rest) = self.args.split_at(N);
180            let mut optional = rest.iter().map(Some).collect_vec();
181            optional.resize(M, None);
182            Ok((
183                required.try_into().unwrap(),
184                optional.try_into().ok().unwrap(),
185            ))
186        } else {
187            let (min, max) = count_range.into_inner();
188            Err(self.invalid_arguments_count(min, Some(max)))
189        }
190    }
191
192    /// Extracts N required arguments and M optional arguments. Some of them can
193    /// be specified as keyword arguments.
194    ///
195    /// `names` is a list of parameter names. Unnamed positional arguments
196    /// should be padded with `""`.
197    #[allow(clippy::type_complexity)]
198    pub fn expect_named_arguments<const N: usize, const M: usize>(
199        &self,
200        names: &[&str],
201    ) -> Result<
202        (
203            [&ExpressionNode<'i, T>; N],
204            [Option<&ExpressionNode<'i, T>>; M],
205        ),
206        InvalidArguments<'i>,
207    > {
208        if self.keyword_args.is_empty() {
209            let (required, optional) = self.expect_arguments::<N, M>()?;
210            // TODO: use .each_ref() if MSRV is bumped to 1.77.0
211            Ok((array::from_fn(|i| &required[i]), optional))
212        } else {
213            let (required, optional) = self.expect_named_arguments_vec(names, N, N + M)?;
214            Ok((
215                required.try_into().ok().unwrap(),
216                optional.try_into().ok().unwrap(),
217            ))
218        }
219    }
220
221    #[allow(clippy::type_complexity)]
222    fn expect_named_arguments_vec(
223        &self,
224        names: &[&str],
225        min: usize,
226        max: usize,
227    ) -> Result<
228        (
229            Vec<&ExpressionNode<'i, T>>,
230            Vec<Option<&ExpressionNode<'i, T>>>,
231        ),
232        InvalidArguments<'i>,
233    > {
234        assert!(names.len() <= max);
235
236        if self.args.len() > max {
237            return Err(self.invalid_arguments_count(min, Some(max)));
238        }
239        let mut extracted = Vec::with_capacity(max);
240        extracted.extend(self.args.iter().map(Some));
241        extracted.resize(max, None);
242
243        for arg in &self.keyword_args {
244            let name = arg.name;
245            let span = arg.name_span.start_pos().span(&arg.value.span.end_pos());
246            let pos = names.iter().position(|&n| n == name).ok_or_else(|| {
247                self.invalid_arguments(format!(r#"Unexpected keyword argument "{name}""#), span)
248            })?;
249            if extracted[pos].is_some() {
250                return Err(self.invalid_arguments(
251                    format!(r#"Got multiple values for keyword "{name}""#),
252                    span,
253                ));
254            }
255            extracted[pos] = Some(&arg.value);
256        }
257
258        let optional = extracted.split_off(min);
259        let required = extracted.into_iter().flatten().collect_vec();
260        if required.len() != min {
261            return Err(self.invalid_arguments_count(min, Some(max)));
262        }
263        Ok((required, optional))
264    }
265
266    fn ensure_no_keyword_arguments(&self) -> Result<(), InvalidArguments<'i>> {
267        if let (Some(first), Some(last)) = (self.keyword_args.first(), self.keyword_args.last()) {
268            let span = first.name_span.start_pos().span(&last.value.span.end_pos());
269            Err(self.invalid_arguments("Unexpected keyword arguments".to_owned(), span))
270        } else {
271            Ok(())
272        }
273    }
274
275    fn invalid_arguments(&self, message: String, span: pest::Span<'i>) -> InvalidArguments<'i> {
276        InvalidArguments {
277            name: self.name,
278            message,
279            span,
280        }
281    }
282
283    fn invalid_arguments_count(&self, min: usize, max: Option<usize>) -> InvalidArguments<'i> {
284        let message = match (min, max) {
285            (min, Some(max)) if min == max => format!("Expected {min} arguments"),
286            (min, Some(max)) => format!("Expected {min} to {max} arguments"),
287            (min, None) => format!("Expected at least {min} arguments"),
288        };
289        self.invalid_arguments(message, self.args_span)
290    }
291
292    fn invalid_arguments_count_with_arities(
293        &self,
294        arities: impl IntoIterator<Item = usize>,
295    ) -> InvalidArguments<'i> {
296        let message = format!("Expected {} arguments", arities.into_iter().join(", "));
297        self.invalid_arguments(message, self.args_span)
298    }
299}
300
301/// Unexpected number of arguments, or invalid combination of arguments.
302///
303/// This error is supposed to be converted to language-specific parse error
304/// type, where lifetime `'i` will be eliminated.
305#[derive(Clone, Debug)]
306pub struct InvalidArguments<'i> {
307    /// Function name.
308    pub name: &'i str,
309    /// Error message.
310    pub message: String,
311    /// Span of the bad arguments.
312    pub span: pest::Span<'i>,
313}
314
315/// Expression item that can be transformed recursively by using `folder: F`.
316pub trait FoldableExpression<'i>: Sized {
317    /// Transforms `self` by applying the `folder` to inner items.
318    fn fold<F>(self, folder: &mut F, span: pest::Span<'i>) -> Result<Self, F::Error>
319    where
320        F: ExpressionFolder<'i, Self> + ?Sized;
321}
322
323/// Visitor-like interface to transform AST nodes recursively.
324pub trait ExpressionFolder<'i, T: FoldableExpression<'i>> {
325    /// Transform error.
326    type Error;
327
328    /// Transforms the expression `node`. By default, inner items are
329    /// transformed recursively.
330    fn fold_expression(
331        &mut self,
332        node: ExpressionNode<'i, T>,
333    ) -> Result<ExpressionNode<'i, T>, Self::Error> {
334        let ExpressionNode { kind, span } = node;
335        let kind = kind.fold(self, span)?;
336        Ok(ExpressionNode { kind, span })
337    }
338
339    /// Transforms identifier.
340    fn fold_identifier(&mut self, name: &'i str, span: pest::Span<'i>) -> Result<T, Self::Error>;
341
342    /// Transforms function call.
343    fn fold_function_call(
344        &mut self,
345        function: Box<FunctionCallNode<'i, T>>,
346        span: pest::Span<'i>,
347    ) -> Result<T, Self::Error>;
348}
349
350/// Transforms list of `nodes` by using `folder`.
351pub fn fold_expression_nodes<'i, F, T>(
352    folder: &mut F,
353    nodes: Vec<ExpressionNode<'i, T>>,
354) -> Result<Vec<ExpressionNode<'i, T>>, F::Error>
355where
356    F: ExpressionFolder<'i, T> + ?Sized,
357    T: FoldableExpression<'i>,
358{
359    nodes
360        .into_iter()
361        .map(|node| folder.fold_expression(node))
362        .try_collect()
363}
364
365/// Transforms function call arguments by using `folder`.
366pub fn fold_function_call_args<'i, F, T>(
367    folder: &mut F,
368    function: FunctionCallNode<'i, T>,
369) -> Result<FunctionCallNode<'i, T>, F::Error>
370where
371    F: ExpressionFolder<'i, T> + ?Sized,
372    T: FoldableExpression<'i>,
373{
374    Ok(FunctionCallNode {
375        name: function.name,
376        name_span: function.name_span,
377        args: fold_expression_nodes(folder, function.args)?,
378        keyword_args: function
379            .keyword_args
380            .into_iter()
381            .map(|arg| {
382                Ok(KeywordArgument {
383                    name: arg.name,
384                    name_span: arg.name_span,
385                    value: folder.fold_expression(arg.value)?,
386                })
387            })
388            .try_collect()?,
389        args_span: function.args_span,
390    })
391}
392
/// Decodes string literals made of content and escape-sequence parts.
#[derive(Debug)]
pub struct StringLiteralParser<R> {
    /// Rule of a plain (unescaped) content part.
    pub content_rule: R,
    /// Rule of an escape sequence, including its leading backslash.
    pub escape_rule: R,
}
401
402impl<R: RuleType> StringLiteralParser<R> {
403    /// Parses the given string literal `pairs` into string.
404    pub fn parse(&self, pairs: Pairs<R>) -> String {
405        let mut result = String::new();
406        for part in pairs {
407            if part.as_rule() == self.content_rule {
408                result.push_str(part.as_str());
409            } else if part.as_rule() == self.escape_rule {
410                match &part.as_str()[1..] {
411                    "\"" => result.push('"'),
412                    "\\" => result.push('\\'),
413                    "t" => result.push('\t'),
414                    "r" => result.push('\r'),
415                    "n" => result.push('\n'),
416                    "0" => result.push('\0'),
417                    "e" => result.push('\x1b'),
418                    hex if hex.starts_with('x') => {
419                        result.push(char::from(
420                            u8::from_str_radix(&hex[1..], 16).expect("hex characters"),
421                        ));
422                    }
423                    char => panic!("invalid escape: \\{char:?}"),
424                }
425            } else {
426                panic!("unexpected part of string: {part:?}");
427            }
428        }
429        result
430    }
431}
432
/// Escapes special characters in `unescaped`.
///
/// `"` and `\` are backslash-escaped. Tab, CR, LF and NUL use the short forms
/// `\t`, `\r`, `\n` and `\0`. Any other ASCII control character becomes
/// `\xNN`. All other characters, including non-ASCII ones, are copied
/// unchanged.
pub fn escape_string(unescaped: &str) -> String {
    let mut escaped = String::with_capacity(unescaped.len());
    for ch in unescaped.chars() {
        let short_form = match ch {
            '"' => Some(r#"\""#),
            '\\' => Some(r#"\\"#),
            '\t' => Some(r#"\t"#),
            '\r' => Some(r#"\r"#),
            '\n' => Some(r#"\n"#),
            '\0' => Some(r#"\0"#),
            _ => None,
        };
        if let Some(seq) = short_form {
            escaped.push_str(seq);
        } else if ch.is_ascii_control() {
            // The char is ASCII here, so the `as u8` cast is lossless.
            escaped.extend(ascii::escape_default(ch as u8).map(char::from));
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
454
/// Parses function call syntax according to a set of grammar rules.
#[derive(Debug)]
pub struct FunctionCallParser<R> {
    /// Rule of the function name.
    pub function_name_rule: R,
    /// Rule of the argument list, which holds positional and keyword
    /// arguments.
    pub function_arguments_rule: R,
    /// Rule of a `name = value` keyword argument.
    pub keyword_argument_rule: R,
    /// Rule of a keyword argument's parameter name.
    pub argument_name_rule: R,
    /// Rule of an argument value expression.
    pub argument_value_rule: R,
}
469
470impl<R: RuleType> FunctionCallParser<R> {
471    /// Parses the given `pair` as function call.
472    pub fn parse<'i, T, E: From<InvalidArguments<'i>>>(
473        &self,
474        pair: Pair<'i, R>,
475        // parse_name can be defined for any Pair<'_, R>, but parse_value should
476        // be allowed to construct T by capturing Pair<'i, R>.
477        parse_name: impl Fn(Pair<'i, R>) -> Result<&'i str, E>,
478        parse_value: impl Fn(Pair<'i, R>) -> Result<ExpressionNode<'i, T>, E>,
479    ) -> Result<FunctionCallNode<'i, T>, E> {
480        let (name_pair, args_pair) = pair.into_inner().collect_tuple().unwrap();
481        assert_eq!(name_pair.as_rule(), self.function_name_rule);
482        assert_eq!(args_pair.as_rule(), self.function_arguments_rule);
483        let name_span = name_pair.as_span();
484        let args_span = args_pair.as_span();
485        let function_name = parse_name(name_pair)?;
486        let mut args = Vec::new();
487        let mut keyword_args = Vec::new();
488        for pair in args_pair.into_inner() {
489            let span = pair.as_span();
490            if pair.as_rule() == self.argument_value_rule {
491                if !keyword_args.is_empty() {
492                    return Err(InvalidArguments {
493                        name: function_name,
494                        message: "Positional argument follows keyword argument".to_owned(),
495                        span,
496                    }
497                    .into());
498                }
499                args.push(parse_value(pair)?);
500            } else if pair.as_rule() == self.keyword_argument_rule {
501                let (name_pair, value_pair) = pair.into_inner().collect_tuple().unwrap();
502                assert_eq!(name_pair.as_rule(), self.argument_name_rule);
503                assert_eq!(value_pair.as_rule(), self.argument_value_rule);
504                let name_span = name_pair.as_span();
505                let arg = KeywordArgument {
506                    name: parse_name(name_pair)?,
507                    name_span,
508                    value: parse_value(value_pair)?,
509                };
510                keyword_args.push(arg);
511            } else {
512                panic!("unexpected argument rule {pair:?}");
513            }
514        }
515        Ok(FunctionCallNode {
516            name: function_name,
517            name_span,
518            args,
519            keyword_args,
520            args_span,
521        })
522    }
523}
524
/// Symbol and function aliases of a language, keyed by name.
#[derive(Debug, Clone, Default)]
pub struct AliasesMap<P, V> {
    // name -> definition
    symbol_aliases: HashMap<String, V>,
    // name -> [(params, defn)], kept sorted by arity (number of params)
    function_aliases: HashMap<String, Vec<(Vec<String>, V)>>,
    // The parser type P ties the map to one language, so an AliasesMap built
    // for a different language can't be used by mistake.
    parser: P,
}
534
535impl<P, V> AliasesMap<P, V> {
536    /// Creates an empty aliases map with default-constructed parser.
537    pub fn new() -> Self
538    where
539        P: Default,
540    {
541        Self {
542            symbol_aliases: Default::default(),
543            function_aliases: Default::default(),
544            parser: Default::default(),
545        }
546    }
547
548    /// Adds new substitution rule `decl = defn`.
549    ///
550    /// Returns error if `decl` is invalid. The `defn` part isn't checked. A bad
551    /// `defn` will be reported when the alias is substituted.
552    pub fn insert(&mut self, decl: impl AsRef<str>, defn: impl Into<V>) -> Result<(), P::Error>
553    where
554        P: AliasDeclarationParser,
555    {
556        match self.parser.parse_declaration(decl.as_ref())? {
557            AliasDeclaration::Symbol(name) => {
558                self.symbol_aliases.insert(name, defn.into());
559            }
560            AliasDeclaration::Function(name, params) => {
561                let overloads = self.function_aliases.entry(name).or_default();
562                match overloads.binary_search_by_key(&params.len(), |(params, _)| params.len()) {
563                    Ok(i) => overloads[i] = (params, defn.into()),
564                    Err(i) => overloads.insert(i, (params, defn.into())),
565                }
566            }
567        }
568        Ok(())
569    }
570
571    /// Iterates symbol names in arbitrary order.
572    pub fn symbol_names(&self) -> impl Iterator<Item = &str> {
573        self.symbol_aliases.keys().map(|n| n.as_ref())
574    }
575
576    /// Iterates function names in arbitrary order.
577    pub fn function_names(&self) -> impl Iterator<Item = &str> {
578        self.function_aliases.keys().map(|n| n.as_ref())
579    }
580
581    /// Looks up symbol alias by name. Returns identifier and definition text.
582    pub fn get_symbol(&self, name: &str) -> Option<(AliasId<'_>, &V)> {
583        self.symbol_aliases
584            .get_key_value(name)
585            .map(|(name, defn)| (AliasId::Symbol(name), defn))
586    }
587
588    /// Looks up function alias by name and arity. Returns identifier, list of
589    /// parameter names, and definition text.
590    pub fn get_function(&self, name: &str, arity: usize) -> Option<(AliasId<'_>, &[String], &V)> {
591        let overloads = self.get_function_overloads(name)?;
592        overloads.find_by_arity(arity)
593    }
594
595    /// Looks up function aliases by name.
596    fn get_function_overloads(&self, name: &str) -> Option<AliasFunctionOverloads<'_, V>> {
597        let (name, overloads) = self.function_aliases.get_key_value(name)?;
598        Some(AliasFunctionOverloads { name, overloads })
599    }
600}
601
/// Borrowed view of all overloads of one function alias.
#[derive(Debug, Clone)]
struct AliasFunctionOverloads<'a, V> {
    /// Alias name, borrowed from the map key.
    name: &'a String,
    /// Overloads as `(params, defn)` pairs, sorted by arity.
    overloads: &'a Vec<(Vec<String>, V)>,
}
607
608impl<'a, V> AliasFunctionOverloads<'a, V> {
609    fn arities(&self) -> impl DoubleEndedIterator<Item = usize> + ExactSizeIterator + 'a {
610        self.overloads.iter().map(|(params, _)| params.len())
611    }
612
613    fn min_arity(&self) -> usize {
614        self.arities().next().unwrap()
615    }
616
617    fn max_arity(&self) -> usize {
618        self.arities().next_back().unwrap()
619    }
620
621    fn find_by_arity(&self, arity: usize) -> Option<(AliasId<'a>, &'a [String], &'a V)> {
622        let index = self
623            .overloads
624            .binary_search_by_key(&arity, |(params, _)| params.len())
625            .ok()?;
626        let (params, defn) = &self.overloads[index];
627        // Exact parameter names aren't needed to identify a function, but they
628        // provide a better error indication. (e.g. "foo(x, y)" is easier to
629        // follow than "foo/2".)
630        Some((AliasId::Function(self.name, params), params, defn))
631    }
632}
633
/// Borrowed identifier of an alias expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AliasId<'a> {
    /// A symbol alias, by name.
    Symbol(&'a str),
    /// A function alias, by name and parameter names.
    Function(&'a str, &'a [String]),
    /// A parameter of a function alias, by name.
    Parameter(&'a str),
}
644
645impl fmt::Display for AliasId<'_> {
646    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
647        match self {
648            AliasId::Symbol(name) => write!(f, "{name}"),
649            AliasId::Function(name, params) => {
650                write!(f, "{name}({params})", params = params.join(", "))
651            }
652            AliasId::Parameter(name) => write!(f, "{name}"),
653        }
654    }
655}
656
/// Parsed left-hand side (declaration) of an alias rule.
#[derive(Debug, Clone)]
pub enum AliasDeclaration {
    /// A symbol alias with the given name.
    Symbol(String),
    /// A function alias with the given name and parameter names.
    Function(String, Vec<String>),
}
665
// AliasDeclarationParser and AliasDefinitionParser could be merged into a single
// trait, but it's unclear whether doing so would simplify the abstraction.
668
669/// Parser for symbol and function alias declaration.
670pub trait AliasDeclarationParser {
671    /// Parse error type.
672    type Error;
673
674    /// Parses symbol or function name and parameters.
675    fn parse_declaration(&self, source: &str) -> Result<AliasDeclaration, Self::Error>;
676}
677
678/// Parser for symbol and function alias definition.
679pub trait AliasDefinitionParser {
680    /// Expression item type.
681    type Output<'i>;
682    /// Parse error type.
683    type Error;
684
685    /// Parses alias body.
686    fn parse_definition<'i>(
687        &self,
688        source: &'i str,
689    ) -> Result<ExpressionNode<'i, Self::Output<'i>>, Self::Error>;
690}
691
692/// Expression item that supports alias substitution.
693pub trait AliasExpandableExpression<'i>: FoldableExpression<'i> {
694    /// Wraps identifier.
695    fn identifier(name: &'i str) -> Self;
696    /// Wraps function call.
697    fn function_call(function: Box<FunctionCallNode<'i, Self>>) -> Self;
698    /// Wraps substituted expression.
699    fn alias_expanded(id: AliasId<'i>, subst: Box<ExpressionNode<'i, Self>>) -> Self;
700}
701
702/// Error that may occur during alias substitution.
703pub trait AliasExpandError: Sized {
704    /// Unexpected number of arguments, or invalid combination of arguments.
705    fn invalid_arguments(err: InvalidArguments<'_>) -> Self;
706    /// Recursion detected during alias substitution.
707    fn recursive_expansion(id: AliasId<'_>, span: pest::Span<'_>) -> Self;
708    /// Attaches alias trace to the current error.
709    fn within_alias_expansion(self, id: AliasId<'_>, span: pest::Span<'_>) -> Self;
710}
711
712/// Expands aliases recursively in tree of `T`.
713#[derive(Debug)]
714struct AliasExpander<'i, T, P> {
715    /// Alias symbols and functions that are globally available.
716    aliases_map: &'i AliasesMap<P, String>,
717    /// Stack of aliases and local parameters currently expanding.
718    states: Vec<AliasExpandingState<'i, T>>,
719}
720
721#[derive(Debug)]
722struct AliasExpandingState<'i, T> {
723    id: AliasId<'i>,
724    locals: HashMap<&'i str, ExpressionNode<'i, T>>,
725}
726
727impl<'i, T, P, E> AliasExpander<'i, T, P>
728where
729    T: AliasExpandableExpression<'i> + Clone,
730    P: AliasDefinitionParser<Output<'i> = T, Error = E>,
731    E: AliasExpandError,
732{
733    fn expand_defn(
734        &mut self,
735        id: AliasId<'i>,
736        defn: &'i str,
737        locals: HashMap<&'i str, ExpressionNode<'i, T>>,
738        span: pest::Span<'i>,
739    ) -> Result<T, E> {
740        // The stack should be short, so let's simply do linear search.
741        if self.states.iter().any(|s| s.id == id) {
742            return Err(E::recursive_expansion(id, span));
743        }
744        self.states.push(AliasExpandingState { id, locals });
745        // Parsed defn could be cached if needed.
746        let result = self
747            .aliases_map
748            .parser
749            .parse_definition(defn)
750            .and_then(|node| self.fold_expression(node))
751            .map(|node| T::alias_expanded(id, Box::new(node)))
752            .map_err(|e| e.within_alias_expansion(id, span));
753        self.states.pop();
754        result
755    }
756}
757
758impl<'i, T, P, E> ExpressionFolder<'i, T> for AliasExpander<'i, T, P>
759where
760    T: AliasExpandableExpression<'i> + Clone,
761    P: AliasDefinitionParser<Output<'i> = T, Error = E>,
762    E: AliasExpandError,
763{
764    type Error = E;
765
766    fn fold_identifier(&mut self, name: &'i str, span: pest::Span<'i>) -> Result<T, Self::Error> {
767        if let Some(subst) = self.states.last().and_then(|s| s.locals.get(name)) {
768            let id = AliasId::Parameter(name);
769            Ok(T::alias_expanded(id, Box::new(subst.clone())))
770        } else if let Some((id, defn)) = self.aliases_map.get_symbol(name) {
771            let locals = HashMap::new(); // Don't spill out the current scope
772            self.expand_defn(id, defn, locals, span)
773        } else {
774            Ok(T::identifier(name))
775        }
776    }
777
778    fn fold_function_call(
779        &mut self,
780        function: Box<FunctionCallNode<'i, T>>,
781        span: pest::Span<'i>,
782    ) -> Result<T, Self::Error> {
783        // For better error indication, builtin functions are shadowed by name,
784        // not by (name, arity).
785        if let Some(overloads) = self.aliases_map.get_function_overloads(function.name) {
786            // TODO: add support for keyword arguments
787            function
788                .ensure_no_keyword_arguments()
789                .map_err(E::invalid_arguments)?;
790            let Some((id, params, defn)) = overloads.find_by_arity(function.arity()) else {
791                let min = overloads.min_arity();
792                let max = overloads.max_arity();
793                let err = if max - min + 1 == overloads.arities().len() {
794                    function.invalid_arguments_count(min, Some(max))
795                } else {
796                    function.invalid_arguments_count_with_arities(overloads.arities())
797                };
798                return Err(E::invalid_arguments(err));
799            };
800            // Resolve arguments in the current scope, and pass them in to the alias
801            // expansion scope.
802            let args = fold_expression_nodes(self, function.args)?;
803            let locals = params.iter().map(|s| s.as_str()).zip(args).collect();
804            self.expand_defn(id, defn, locals, span)
805        } else {
806            let function = Box::new(fold_function_call_args(self, *function)?);
807            Ok(T::function_call(function))
808        }
809    }
810}
811
812/// Expands aliases recursively.
813pub fn expand_aliases<'i, T, P>(
814    node: ExpressionNode<'i, T>,
815    aliases_map: &'i AliasesMap<P, String>,
816) -> Result<ExpressionNode<'i, T>, P::Error>
817where
818    T: AliasExpandableExpression<'i> + Clone,
819    P: AliasDefinitionParser<Output<'i> = T>,
820    P::Error: AliasExpandError,
821{
822    let mut expander = AliasExpander {
823        aliases_map,
824        states: Vec::new(),
825    };
826    expander.fold_expression(node)
827}
828
829/// Collects similar names from the `candidates` list.
830pub fn collect_similar<I>(name: &str, candidates: I) -> Vec<String>
831where
832    I: IntoIterator,
833    I::Item: AsRef<str>,
834{
835    candidates
836        .into_iter()
837        .filter(|cand| {
838            // The parameter is borrowed from clap f5540d26
839            strsim::jaro(name, cand.as_ref()) > 0.7
840        })
841        .map(|s| s.as_ref().to_owned())
842        .sorted_unstable()
843        .collect()
844}
845
#[cfg(test)]
mod tests {
    use super::*;

    /// Zero-length span; source positions don't matter for argument checks.
    fn blank_span() -> pest::Span<'static> {
        pest::Span::new("", 0, 0).unwrap()
    }

    /// Builds a call node with the given positional and keyword arguments.
    fn call(
        name: &'static str,
        args: impl Into<Vec<ExpressionNode<'static, u32>>>,
        keyword_args: impl Into<Vec<KeywordArgument<'static, u32>>>,
    ) -> FunctionCallNode<'static, u32> {
        FunctionCallNode {
            name,
            name_span: blank_span(),
            args: args.into(),
            keyword_args: keyword_args.into(),
            args_span: blank_span(),
        }
    }

    /// Builds a value node holding `v`.
    fn node(v: u32) -> ExpressionNode<'static, u32> {
        ExpressionNode::new(v, blank_span())
    }

    /// Builds the keyword argument `name = v`.
    fn kwarg(name: &'static str, v: u32) -> KeywordArgument<'static, u32> {
        KeywordArgument {
            name,
            name_span: blank_span(),
            value: node(v),
        }
    }

    #[test]
    fn test_expect_arguments() {
        // foo()
        let f = call("foo", [], []);
        assert!(f.expect_no_arguments().is_ok());
        assert!(f.expect_some_arguments::<0>().is_ok());
        assert!(f.expect_arguments::<0, 0>().is_ok());
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_ok());

        // foo(0)
        let f = call("foo", [node(0)], []);
        assert!(f.expect_no_arguments().is_err());
        assert_eq!(
            f.expect_some_arguments::<0>().unwrap(),
            (&[], [node(0)].as_slice())
        );
        assert_eq!(
            f.expect_some_arguments::<1>().unwrap(),
            (&[node(0)], [].as_slice())
        );
        assert!(f.expect_arguments::<0, 0>().is_err());
        assert_eq!(
            f.expect_arguments::<0, 1>().unwrap(),
            (&[], [Some(&node(0))])
        );
        assert_eq!(f.expect_arguments::<1, 1>().unwrap(), (&[node(0)], [None]));
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_err());
        assert_eq!(
            f.expect_named_arguments::<0, 1>(&["a"]).unwrap(),
            ([], [Some(&node(0))])
        );
        assert_eq!(
            f.expect_named_arguments::<1, 0>(&["a"]).unwrap(),
            ([&node(0)], [])
        );

        // foo(a = 0)
        let f = call("foo", [], [kwarg("a", 0)]);
        assert!(f.expect_no_arguments().is_err());
        assert!(f.expect_some_arguments::<1>().is_err());
        assert!(f.expect_arguments::<0, 1>().is_err());
        assert!(f.expect_arguments::<1, 0>().is_err());
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_err());
        assert!(f.expect_named_arguments::<0, 1>(&[]).is_err());
        assert!(f.expect_named_arguments::<1, 0>(&[]).is_err());
        assert_eq!(
            f.expect_named_arguments::<1, 0>(&["a"]).unwrap(),
            ([&node(0)], [])
        );
        assert_eq!(
            f.expect_named_arguments::<1, 1>(&["a", "b"]).unwrap(),
            ([&node(0)], [None])
        );
        assert!(f.expect_named_arguments::<1, 1>(&["b", "a"]).is_err());

        // foo(0, a = 1, b = 2)
        let f = call("foo", [node(0)], [kwarg("a", 1), kwarg("b", 2)]);
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_err());
        assert!(f.expect_named_arguments::<1, 1>(&["a", "b"]).is_err());
        assert_eq!(
            f.expect_named_arguments::<1, 2>(&["c", "a", "b"]).unwrap(),
            ([&node(0)], [Some(&node(1)), Some(&node(2))])
        );
        assert_eq!(
            f.expect_named_arguments::<2, 1>(&["c", "b", "a"]).unwrap(),
            ([&node(0), &node(2)], [Some(&node(1))])
        );
        assert_eq!(
            f.expect_named_arguments::<0, 3>(&["c", "b", "a"]).unwrap(),
            ([], [Some(&node(0)), Some(&node(2)), Some(&node(1))])
        );

        // foo(a = 0, a = 1): duplicate keyword
        let f = call("foo", [], [kwarg("a", 0), kwarg("a", 1)]);
        assert!(f.expect_named_arguments::<1, 1>(&["", "a"]).is_err());
    }
}