Skip to main content

jj_lib/
dsl_util.rs

1// Copyright 2020-2024 The Jujutsu Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Domain-specific language helpers.
16
17use std::ascii;
18use std::collections::HashMap;
19use std::fmt;
20use std::slice;
21
22use itertools::Itertools as _;
23use pest::RuleType;
24use pest::iterators::Pair;
25use pest::iterators::Pairs;
26
/// Collector of diagnostic messages emitted during parsing.
///
/// `T` is usually a parse error type of the language, which contains a message
/// and source span of 'static lifetime.
#[derive(Debug)]
pub struct Diagnostics<T> {
    // This might be extended to [{ kind: Warning|Error, message: T }, ..].
    diagnostics: Vec<T>,
}

impl<T> Diagnostics<T> {
    /// Creates a collector that holds no messages yet.
    pub fn new() -> Self {
        let diagnostics = Vec::new();
        Self { diagnostics }
    }

    /// Returns `true` if no diagnostic message has been recorded.
    pub fn is_empty(&self) -> bool {
        self.diagnostics.as_slice().is_empty()
    }

    /// Returns how many diagnostic messages have been recorded.
    pub fn len(&self) -> usize {
        self.diagnostics.as_slice().len()
    }

    /// Returns a borrowing iterator over the recorded messages, in insertion
    /// order.
    pub fn iter(&self) -> slice::Iter<'_, T> {
        self.diagnostics.as_slice().iter()
    }

    /// Records `diag` as a warning-level message.
    pub fn add_warning(&mut self, diag: T) {
        self.diagnostics.push(diag);
    }

    /// Drains messages of another type into this collector, converting each
    /// one with `f` (for example fileset warnings emitted within a `file()`
    /// revset.)
    pub fn extend_with<U>(&mut self, diagnostics: Diagnostics<U>, f: impl FnMut(U) -> T) {
        let Diagnostics {
            diagnostics: incoming,
        } = diagnostics;
        self.diagnostics.extend(incoming.into_iter().map(f));
    }
}

impl<T> Default for Diagnostics<T> {
    /// Equivalent to an empty collector.
    fn default() -> Self {
        Self {
            diagnostics: Vec::new(),
        }
    }
}

impl<'a, T> IntoIterator for &'a Diagnostics<T> {
    type Item = &'a T;
    type IntoIter = slice::Iter<'a, T>;

    /// Iterates the recorded messages by reference.
    fn into_iter(self) -> Self::IntoIter {
        self.diagnostics.iter()
    }
}
87
88/// AST node without type or name checking.
89#[derive(Clone, Debug, Eq, PartialEq)]
90pub struct ExpressionNode<'i, T> {
91    /// Expression item such as identifier, literal, function call, etc.
92    pub kind: T,
93    /// Span of the node.
94    pub span: pest::Span<'i>,
95}
96
97impl<'i, T> ExpressionNode<'i, T> {
98    /// Wraps the given expression and span.
99    pub fn new(kind: T, span: pest::Span<'i>) -> Self {
100        Self { kind, span }
101    }
102}
103
/// `<name>:<value>` expression in AST.
///
/// The name part is kept as a borrowed string together with its span; the
/// value part is an arbitrary nested expression node.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PatternNode<'i, T> {
    /// Pattern name or type (such as `glob`.)
    pub name: &'i str,
    /// Span of the pattern name.
    pub name_span: pest::Span<'i>,
    /// Value expression.
    pub value: ExpressionNode<'i, T>,
}
114
/// Function call in AST.
///
/// Positional and keyword arguments are stored in separate lists, each in
/// the order they were pushed by the parser.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FunctionCallNode<'i, T> {
    /// Function name.
    pub name: &'i str,
    /// Span of the function name.
    pub name_span: pest::Span<'i>,
    /// List of positional arguments.
    pub args: Vec<ExpressionNode<'i, T>>,
    /// List of keyword arguments.
    pub keyword_args: Vec<KeywordArgument<'i, T>>,
    /// Span of the arguments list. Argument-count errors are reported at this
    /// span.
    pub args_span: pest::Span<'i>,
}
129
/// Keyword argument pair (parameter name and value) in AST.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct KeywordArgument<'i, T> {
    /// Parameter name.
    pub name: &'i str,
    /// Span of the parameter name.
    pub name_span: pest::Span<'i>,
    /// Value expression.
    pub value: ExpressionNode<'i, T>,
}
140
141impl<'i, T> FunctionCallNode<'i, T> {
142    /// Number of arguments assuming named arguments are all unique.
143    pub fn arity(&self) -> usize {
144        self.args.len() + self.keyword_args.len()
145    }
146
147    /// Ensures that no arguments passed.
148    pub fn expect_no_arguments(&self) -> Result<(), InvalidArguments<'i>> {
149        let ([], []) = self.expect_arguments()?;
150        Ok(())
151    }
152
153    /// Extracts exactly N required arguments.
154    pub fn expect_exact_arguments<const N: usize>(
155        &self,
156    ) -> Result<&[ExpressionNode<'i, T>; N], InvalidArguments<'i>> {
157        let (args, []) = self.expect_arguments()?;
158        Ok(args)
159    }
160
161    /// Extracts N required arguments and remainders.
162    ///
163    /// This can be used to get all the positional arguments without requiring
164    /// any (N = 0):
165    /// ```ignore
166    /// let ([], content_nodes) = function.expect_some_arguments()?;
167    /// ```
168    /// Avoid accessing `function.args` directly, as that may allow keyword
169    /// arguments to be silently ignored.
170    #[expect(clippy::type_complexity)]
171    pub fn expect_some_arguments<const N: usize>(
172        &self,
173    ) -> Result<(&[ExpressionNode<'i, T>; N], &[ExpressionNode<'i, T>]), InvalidArguments<'i>> {
174        self.ensure_no_keyword_arguments()?;
175        if self.args.len() >= N {
176            let (required, rest) = self.args.split_at(N);
177            Ok((required.try_into().unwrap(), rest))
178        } else {
179            Err(self.invalid_arguments_count(N, None))
180        }
181    }
182
183    /// Extracts N required arguments and M optional arguments.
184    #[expect(clippy::type_complexity)]
185    pub fn expect_arguments<const N: usize, const M: usize>(
186        &self,
187    ) -> Result<
188        (
189            &[ExpressionNode<'i, T>; N],
190            [Option<&ExpressionNode<'i, T>>; M],
191        ),
192        InvalidArguments<'i>,
193    > {
194        self.ensure_no_keyword_arguments()?;
195        let count_range = N..=(N + M);
196        if count_range.contains(&self.args.len()) {
197            let (required, rest) = self.args.split_at(N);
198            let mut optional = rest.iter().map(Some).collect_vec();
199            optional.resize(M, None);
200            Ok((
201                required.try_into().unwrap(),
202                optional.try_into().ok().unwrap(),
203            ))
204        } else {
205            let (min, max) = count_range.into_inner();
206            Err(self.invalid_arguments_count(min, Some(max)))
207        }
208    }
209
210    /// Extracts N required arguments and M optional arguments. Some of them can
211    /// be specified as keyword arguments.
212    ///
213    /// `names` is a list of parameter names. Unnamed positional arguments
214    /// should be padded with `""`.
215    #[expect(clippy::type_complexity)]
216    pub fn expect_named_arguments<const N: usize, const M: usize>(
217        &self,
218        names: &[&str],
219    ) -> Result<
220        (
221            [&ExpressionNode<'i, T>; N],
222            [Option<&ExpressionNode<'i, T>>; M],
223        ),
224        InvalidArguments<'i>,
225    > {
226        if self.keyword_args.is_empty() {
227            let (required, optional) = self.expect_arguments::<N, M>()?;
228            Ok((required.each_ref(), optional))
229        } else {
230            let (required, optional) = self.expect_named_arguments_vec(names, N, N + M)?;
231            Ok((
232                required.try_into().ok().unwrap(),
233                optional.try_into().ok().unwrap(),
234            ))
235        }
236    }
237
238    #[expect(clippy::type_complexity)]
239    fn expect_named_arguments_vec(
240        &self,
241        names: &[&str],
242        min: usize,
243        max: usize,
244    ) -> Result<
245        (
246            Vec<&ExpressionNode<'i, T>>,
247            Vec<Option<&ExpressionNode<'i, T>>>,
248        ),
249        InvalidArguments<'i>,
250    > {
251        assert!(names.len() <= max);
252
253        if self.args.len() > max {
254            return Err(self.invalid_arguments_count(min, Some(max)));
255        }
256        let mut extracted = Vec::with_capacity(max);
257        extracted.extend(self.args.iter().map(Some));
258        extracted.resize(max, None);
259
260        for arg in &self.keyword_args {
261            let name = arg.name;
262            let span = arg.name_span.start_pos().span(&arg.value.span.end_pos());
263            let pos = names.iter().position(|&n| n == name).ok_or_else(|| {
264                self.invalid_arguments(format!(r#"Unexpected keyword argument "{name}""#), span)
265            })?;
266            if extracted[pos].is_some() {
267                return Err(self.invalid_arguments(
268                    format!(r#"Got multiple values for keyword "{name}""#),
269                    span,
270                ));
271            }
272            extracted[pos] = Some(&arg.value);
273        }
274
275        let optional = extracted.split_off(min);
276        let required = extracted.into_iter().flatten().collect_vec();
277        if required.len() != min {
278            return Err(self.invalid_arguments_count(min, Some(max)));
279        }
280        Ok((required, optional))
281    }
282
283    fn ensure_no_keyword_arguments(&self) -> Result<(), InvalidArguments<'i>> {
284        if let (Some(first), Some(last)) = (self.keyword_args.first(), self.keyword_args.last()) {
285            let span = first.name_span.start_pos().span(&last.value.span.end_pos());
286            Err(self.invalid_arguments("Unexpected keyword arguments".to_owned(), span))
287        } else {
288            Ok(())
289        }
290    }
291
292    fn invalid_arguments(&self, message: String, span: pest::Span<'i>) -> InvalidArguments<'i> {
293        InvalidArguments {
294            name: self.name,
295            message,
296            span,
297        }
298    }
299
300    fn invalid_arguments_count(&self, min: usize, max: Option<usize>) -> InvalidArguments<'i> {
301        let message = match (min, max) {
302            (min, Some(max)) if min == max => format!("Expected {min} arguments"),
303            (min, Some(max)) => format!("Expected {min} to {max} arguments"),
304            (min, None) => format!("Expected at least {min} arguments"),
305        };
306        self.invalid_arguments(message, self.args_span)
307    }
308
309    fn invalid_arguments_count_with_arities(
310        &self,
311        arities: impl IntoIterator<Item = usize>,
312    ) -> InvalidArguments<'i> {
313        let message = format!("Expected {} arguments", arities.into_iter().join(", "));
314        self.invalid_arguments(message, self.args_span)
315    }
316}
317
/// Unexpected number of arguments, or invalid combination of arguments.
///
/// This error is supposed to be converted to language-specific parse error
/// type, where lifetime `'i` will be eliminated.
#[derive(Clone, Debug)]
pub struct InvalidArguments<'i> {
    /// Name of the function whose arguments were rejected.
    pub name: &'i str,
    /// Error message.
    pub message: String,
    /// Span of the bad arguments.
    pub span: pest::Span<'i>,
}
331
/// Expression item that can be transformed recursively by using `folder: F`.
pub trait FoldableExpression<'i>: Sized {
    /// Transforms `self` by applying the `folder` to inner items.
    ///
    /// `span` is the source span of the node that wraps `self`. On failure,
    /// the folder's error type is returned.
    fn fold<F>(self, folder: &mut F, span: pest::Span<'i>) -> Result<Self, F::Error>
    where
        F: ExpressionFolder<'i, Self> + ?Sized;
}
339
340/// Visitor-like interface to transform AST nodes recursively.
341pub trait ExpressionFolder<'i, T: FoldableExpression<'i>> {
342    /// Transform error.
343    type Error;
344
345    /// Transforms the expression `node`. By default, inner items are
346    /// transformed recursively.
347    fn fold_expression(
348        &mut self,
349        node: ExpressionNode<'i, T>,
350    ) -> Result<ExpressionNode<'i, T>, Self::Error> {
351        let ExpressionNode { kind, span } = node;
352        let kind = kind.fold(self, span)?;
353        Ok(ExpressionNode { kind, span })
354    }
355
356    /// Transforms identifier.
357    fn fold_identifier(&mut self, name: &'i str, span: pest::Span<'i>) -> Result<T, Self::Error>;
358
359    /// Transforms pattern.
360    fn fold_pattern(
361        &mut self,
362        pattern: Box<PatternNode<'i, T>>,
363        span: pest::Span<'i>,
364    ) -> Result<T, Self::Error>;
365
366    /// Transforms function call.
367    fn fold_function_call(
368        &mut self,
369        function: Box<FunctionCallNode<'i, T>>,
370        span: pest::Span<'i>,
371    ) -> Result<T, Self::Error>;
372}
373
374/// Transforms list of `nodes` by using `folder`.
375pub fn fold_expression_nodes<'i, F, T>(
376    folder: &mut F,
377    nodes: Vec<ExpressionNode<'i, T>>,
378) -> Result<Vec<ExpressionNode<'i, T>>, F::Error>
379where
380    F: ExpressionFolder<'i, T> + ?Sized,
381    T: FoldableExpression<'i>,
382{
383    nodes
384        .into_iter()
385        .map(|node| folder.fold_expression(node))
386        .try_collect()
387}
388
389/// Transforms pattern value by using `folder`.
390pub fn fold_pattern_value<'i, F, T>(
391    folder: &mut F,
392    pattern: PatternNode<'i, T>,
393) -> Result<PatternNode<'i, T>, F::Error>
394where
395    F: ExpressionFolder<'i, T> + ?Sized,
396    T: FoldableExpression<'i>,
397{
398    Ok(PatternNode {
399        name: pattern.name,
400        name_span: pattern.name_span,
401        value: folder.fold_expression(pattern.value)?,
402    })
403}
404
405/// Transforms function call arguments by using `folder`.
406pub fn fold_function_call_args<'i, F, T>(
407    folder: &mut F,
408    function: FunctionCallNode<'i, T>,
409) -> Result<FunctionCallNode<'i, T>, F::Error>
410where
411    F: ExpressionFolder<'i, T> + ?Sized,
412    T: FoldableExpression<'i>,
413{
414    Ok(FunctionCallNode {
415        name: function.name,
416        name_span: function.name_span,
417        args: fold_expression_nodes(folder, function.args)?,
418        keyword_args: function
419            .keyword_args
420            .into_iter()
421            .map(|arg| {
422                Ok(KeywordArgument {
423                    name: arg.name,
424                    name_span: arg.name_span,
425                    value: folder.fold_expression(arg.value)?,
426                })
427            })
428            .try_collect()?,
429        args_span: function.args_span,
430    })
431}
432
433/// Helper to parse string literal.
434#[derive(Debug)]
435pub struct StringLiteralParser<R> {
436    /// String content part.
437    pub content_rule: R,
438    /// Escape sequence part including backslash character.
439    pub escape_rule: R,
440}
441
442impl<R: RuleType> StringLiteralParser<R> {
443    /// Parses the given string literal `pairs` into string.
444    pub fn parse(&self, pairs: Pairs<R>) -> String {
445        let mut result = String::new();
446        for part in pairs {
447            if part.as_rule() == self.content_rule {
448                result.push_str(part.as_str());
449            } else if part.as_rule() == self.escape_rule {
450                match &part.as_str()[1..] {
451                    "\"" => result.push('"'),
452                    "\\" => result.push('\\'),
453                    "t" => result.push('\t'),
454                    "r" => result.push('\r'),
455                    "n" => result.push('\n'),
456                    "0" => result.push('\0'),
457                    "e" => result.push('\x1b'),
458                    hex if hex.starts_with('x') => {
459                        result.push(char::from(
460                            u8::from_str_radix(&hex[1..], 16).expect("hex characters"),
461                        ));
462                    }
463                    char => panic!("invalid escape: \\{char:?}"),
464                }
465            } else {
466                panic!("unexpected part of string: {part:?}");
467            }
468        }
469        result
470    }
471}
472
/// Escapes special characters in `unescaped`.
///
/// Double quote, backslash, tab, carriage return, newline and NUL get their
/// short backslash forms; any other ASCII control character is written with
/// [`ascii::escape_default`]; everything else is copied unchanged.
pub fn escape_string(unescaped: &str) -> String {
    let mut out = String::with_capacity(unescaped.len());
    for ch in unescaped.chars() {
        let short_form = match ch {
            '"' => Some(r#"\""#),
            '\\' => Some(r#"\\"#),
            '\t' => Some(r#"\t"#),
            '\r' => Some(r#"\r"#),
            '\n' => Some(r#"\n"#),
            '\0' => Some(r#"\0"#),
            _ => None,
        };
        if let Some(sequence) = short_form {
            out.push_str(sequence);
        } else if ch.is_ascii_control() {
            // `ch` is ASCII here, so narrowing to `u8` is lossless.
            out.extend(ascii::escape_default(ch as u8).map(char::from));
        } else {
            out.push(ch);
        }
    }
    out
}
494
495/// Helper to parse function call.
496#[derive(Debug)]
497pub struct FunctionCallParser<R> {
498    /// Function name.
499    pub function_name_rule: R,
500    /// List of positional and keyword arguments.
501    pub function_arguments_rule: R,
502    /// Pair of parameter name and value.
503    pub keyword_argument_rule: R,
504    /// Parameter name.
505    pub argument_name_rule: R,
506    /// Value expression.
507    pub argument_value_rule: R,
508}
509
510impl<R: RuleType> FunctionCallParser<R> {
511    /// Parses the given `pair` as function call.
512    pub fn parse<'i, T, E: From<InvalidArguments<'i>>>(
513        &self,
514        pair: Pair<'i, R>,
515        // parse_name can be defined for any Pair<'_, R>, but parse_value should
516        // be allowed to construct T by capturing Pair<'i, R>.
517        parse_name: impl Fn(Pair<'i, R>) -> Result<&'i str, E>,
518        parse_value: impl Fn(Pair<'i, R>) -> Result<ExpressionNode<'i, T>, E>,
519    ) -> Result<FunctionCallNode<'i, T>, E> {
520        let [name_pair, args_pair] = pair.into_inner().collect_array().unwrap();
521        assert_eq!(name_pair.as_rule(), self.function_name_rule);
522        assert_eq!(args_pair.as_rule(), self.function_arguments_rule);
523        let name_span = name_pair.as_span();
524        let args_span = args_pair.as_span();
525        let function_name = parse_name(name_pair)?;
526        let mut args = Vec::new();
527        let mut keyword_args = Vec::new();
528        for pair in args_pair.into_inner() {
529            let span = pair.as_span();
530            if pair.as_rule() == self.argument_value_rule {
531                if !keyword_args.is_empty() {
532                    return Err(InvalidArguments {
533                        name: function_name,
534                        message: "Positional argument follows keyword argument".to_owned(),
535                        span,
536                    }
537                    .into());
538                }
539                args.push(parse_value(pair)?);
540            } else if pair.as_rule() == self.keyword_argument_rule {
541                let [name_pair, value_pair] = pair.into_inner().collect_array().unwrap();
542                assert_eq!(name_pair.as_rule(), self.argument_name_rule);
543                assert_eq!(value_pair.as_rule(), self.argument_value_rule);
544                let name_span = name_pair.as_span();
545                let arg = KeywordArgument {
546                    name: parse_name(name_pair)?,
547                    name_span,
548                    value: parse_value(value_pair)?,
549                };
550                keyword_args.push(arg);
551            } else {
552                panic!("unexpected argument rule {pair:?}");
553            }
554        }
555        Ok(FunctionCallNode {
556            name: function_name,
557            name_span,
558            args,
559            keyword_args,
560            args_span,
561        })
562    }
563}
564
/// A function alias containing `(params, definition, description)`, where
/// `params` is the list of parameter names and `description` is optional.
type FunctionAlias<V> = (Vec<String>, V, Option<String>);
567
568/// Map of symbol, pattern, and function aliases.
569#[derive(Clone, Debug, Default)]
570pub struct AliasesMap<P, V> {
571    symbol_aliases: HashMap<String, (V, Option<String>)>,
572    // name: (param, defn)
573    pattern_aliases: HashMap<String, (String, V, Option<String>)>,
574    // name: [(params, defn)] (sorted by arity)
575    function_aliases: HashMap<String, Vec<FunctionAlias<V>>>,
576    // Parser type P helps prevent misuse of AliasesMap of different language.
577    parser: P,
578}
579
580impl<P, V> AliasesMap<P, V> {
581    /// Creates an empty aliases map with default-constructed parser.
582    pub fn new() -> Self
583    where
584        P: Default,
585    {
586        Self {
587            symbol_aliases: Default::default(),
588            pattern_aliases: Default::default(),
589            function_aliases: Default::default(),
590            parser: Default::default(),
591        }
592    }
593
594    /// Adds new substitution rule `decl = defn`.
595    ///
596    /// Returns error if `decl` is invalid. The `defn` part isn't checked. A bad
597    /// `defn` will be reported when the alias is substituted.
598    pub fn insert(
599        &mut self,
600        decl: impl AsRef<str>,
601        defn: impl Into<V>,
602        doc: Option<String>,
603    ) -> Result<(), P::Error>
604    where
605        P: AliasDeclarationParser,
606    {
607        match self.parser.parse_declaration(decl.as_ref())? {
608            AliasDeclaration::Symbol(name) => {
609                self.symbol_aliases.insert(name, (defn.into(), doc));
610            }
611            AliasDeclaration::Pattern(name, param) => {
612                self.pattern_aliases.insert(name, (param, defn.into(), doc));
613            }
614            AliasDeclaration::Function(name, params) => {
615                let overloads = self.function_aliases.entry(name).or_default();
616                match overloads.binary_search_by_key(&params.len(), |(params, _, _)| params.len()) {
617                    Ok(i) => overloads[i] = (params, defn.into(), doc),
618                    Err(i) => overloads.insert(i, (params, defn.into(), doc)),
619                }
620            }
621        }
622        Ok(())
623    }
624
625    /// Iterates symbol names in arbitrary order.
626    pub fn symbol_names(&self) -> impl Iterator<Item = &str> {
627        self.symbol_aliases.keys().map(|n| n.as_ref())
628    }
629
630    /// Iterates pattern names in arbitrary order.
631    pub fn pattern_names(&self) -> impl Iterator<Item = &str> {
632        self.pattern_aliases.keys().map(|n| n.as_ref())
633    }
634
635    /// Iterates function names in arbitrary order.
636    pub fn function_names(&self) -> impl Iterator<Item = &str> {
637        self.function_aliases.keys().map(|n| n.as_ref())
638    }
639
640    /// Looks up symbol alias by name. Returns identifier, definition text, and
641    /// optional description.
642    pub fn get_symbol(&self, name: &str) -> Option<(AliasId<'_>, &V, Option<&str>)> {
643        self.symbol_aliases
644            .get_key_value(name)
645            .map(|(name, (defn, doc))| (AliasId::Symbol(name), defn, doc.as_deref()))
646    }
647
648    /// Looks up pattern alias by name. Returns identifier, parameter name,
649    /// definition text, and optional description.
650    pub fn get_pattern(&self, name: &str) -> Option<(AliasId<'_>, &str, &V, Option<&str>)> {
651        self.pattern_aliases
652            .get_key_value(name)
653            .map(|(name, (param, defn, doc))| {
654                (
655                    AliasId::Pattern(name, param),
656                    param.as_ref(),
657                    defn,
658                    doc.as_deref(),
659                )
660            })
661    }
662
663    /// Looks up function alias by name and arity. Returns identifier, list of
664    /// parameter names, definition text, and optional description.
665    pub fn get_function(
666        &self,
667        name: &str,
668        arity: usize,
669    ) -> Option<(AliasId<'_>, &[String], &V, Option<&str>)> {
670        let overloads = self.get_function_overloads(name)?;
671        overloads.find_by_arity(arity)
672    }
673
674    /// Looks up function aliases by name.
675    fn get_function_overloads(&self, name: &str) -> Option<AliasFunctionOverloads<'_, V>> {
676        let (name, overloads) = self.function_aliases.get_key_value(name)?;
677        Some(AliasFunctionOverloads { name, overloads })
678    }
679}
680
681#[derive(Clone, Debug)]
682struct AliasFunctionOverloads<'a, V> {
683    name: &'a String,
684    overloads: &'a Vec<(Vec<String>, V, Option<String>)>,
685}
686
687impl<'a, V> AliasFunctionOverloads<'a, V> {
688    fn arities(&self) -> impl DoubleEndedIterator<Item = usize> + ExactSizeIterator {
689        self.overloads.iter().map(|(params, _, _)| params.len())
690    }
691
692    fn min_arity(&self) -> usize {
693        self.arities().next().unwrap()
694    }
695
696    fn max_arity(&self) -> usize {
697        self.arities().next_back().unwrap()
698    }
699
700    fn find_by_arity(
701        &self,
702        arity: usize,
703    ) -> Option<(AliasId<'a>, &'a [String], &'a V, Option<&'a str>)> {
704        let index = self
705            .overloads
706            .binary_search_by_key(&arity, |(params, _, _)| params.len())
707            .ok()?;
708        let (params, defn, doc) = &self.overloads[index];
709        // Exact parameter names aren't needed to identify a function, but they
710        // provide a better error indication. (e.g. "foo(x, y)" is easier to
711        // follow than "foo/2".)
712        Some((
713            AliasId::Function(self.name, params),
714            params,
715            defn,
716            doc.as_deref(),
717        ))
718    }
719}
720
/// Borrowed key that identifies an alias expression.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AliasId<'a> {
    /// Symbol name.
    Symbol(&'a str),
    /// Pattern name and parameter name.
    Pattern(&'a str, &'a str),
    /// Function name and parameter names.
    Function(&'a str, &'a [String]),
    /// Function parameter name.
    Parameter(&'a str),
}

impl fmt::Display for AliasId<'_> {
    /// Formats as `name`, `name:param`, or `name(param, ..)` depending on the
    /// kind of alias.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Symbol(name) | Self::Parameter(name) => f.write_str(name),
            Self::Pattern(name, param) => write!(f, "{name}:{param}"),
            Self::Function(name, params) => {
                let joined = params.join(", ");
                write!(f, "{name}({joined})")
            }
        }
    }
}
746
/// Parsed declaration part (the left-hand side) of an alias rule.
#[derive(Clone, Debug)]
pub enum AliasDeclaration {
    /// Symbol name.
    Symbol(String),
    /// Pattern name and parameter.
    Pattern(String, String),
    /// Function name and parameters.
    Function(String, Vec<String>),
}
757
758// AliasDeclarationParser and AliasDefinitionParser can be merged into a single
759// trait, but it's unclear whether doing that would simplify the abstraction.
760
/// Parser for symbol, pattern, and function alias declaration.
pub trait AliasDeclarationParser {
    /// Parse error type.
    type Error;

    /// Parses the declaration `source` into a symbol name, a pattern name with
    /// its parameter, or a function name with its parameters.
    fn parse_declaration(&self, source: &str) -> Result<AliasDeclaration, Self::Error>;
}
769
/// Parser for alias definition (the body that an alias expands to).
pub trait AliasDefinitionParser {
    /// Expression item type.
    type Output<'i>;
    /// Parse error type.
    type Error;

    /// Parses alias body `source` into an expression node borrowing from it.
    fn parse_definition<'i>(
        &self,
        source: &'i str,
    ) -> Result<ExpressionNode<'i, Self::Output<'i>>, Self::Error>;
}
783
/// Expression item that supports alias substitution.
///
/// The constructors below let the alias expander rebuild expression items
/// without knowing the concrete expression type.
pub trait AliasExpandableExpression<'i>: FoldableExpression<'i> {
    /// Wraps identifier.
    fn identifier(name: &'i str) -> Self;
    /// Wraps pattern.
    fn pattern(pattern: Box<PatternNode<'i, Self>>) -> Self;
    /// Wraps function call.
    fn function_call(function: Box<FunctionCallNode<'i, Self>>) -> Self;
    /// Wraps substituted expression `subst`, tagged with the alias `id` it was
    /// expanded from.
    fn alias_expanded(id: AliasId<'i>, subst: Box<ExpressionNode<'i, Self>>) -> Self;
}
795
/// Error that may occur during alias substitution.
pub trait AliasExpandError: Sized {
    /// Unexpected number of arguments, or invalid combination of arguments.
    fn invalid_arguments(err: InvalidArguments<'_>) -> Self;
    /// Recursion detected during alias substitution. `id` is the alias that
    /// was about to be expanded again, and `span` is where it was referenced.
    fn recursive_expansion(id: AliasId<'_>, span: pest::Span<'_>) -> Self;
    /// Attaches alias trace to the current error.
    fn within_alias_expansion(self, id: AliasId<'_>, span: pest::Span<'_>) -> Self;
}
805
/// Expands aliases recursively in tree of `T`.
#[derive(Debug)]
struct AliasExpander<'i, 'a, T, P> {
    /// Alias symbols and functions that are globally available.
    aliases_map: &'i AliasesMap<P, String>,
    /// Local variables set in the outermost scope. These are used while no
    /// alias is being expanded (i.e. `states` is empty.)
    locals: &'a HashMap<&'i str, ExpressionNode<'i, T>>,
    /// Stack of aliases and local parameters currently expanding.
    states: Vec<AliasExpandingState<'i, T>>,
}
816
/// One entry of the alias expansion stack.
#[derive(Debug)]
struct AliasExpandingState<'i, T> {
    /// Alias being expanded; compared against new expansions to detect
    /// recursion.
    id: AliasId<'i>,
    /// Local variables visible while this alias is being expanded.
    locals: HashMap<&'i str, ExpressionNode<'i, T>>,
}
822
823impl<'i, T, P, E> AliasExpander<'i, '_, T, P>
824where
825    T: AliasExpandableExpression<'i> + Clone,
826    P: AliasDefinitionParser<Output<'i> = T, Error = E>,
827    E: AliasExpandError,
828{
829    /// Local variables available to the current scope.
830    fn current_locals(&self) -> &HashMap<&'i str, ExpressionNode<'i, T>> {
831        self.states.last().map_or(self.locals, |s| &s.locals)
832    }
833
834    fn expand_defn(
835        &mut self,
836        id: AliasId<'i>,
837        defn: &'i str,
838        locals: HashMap<&'i str, ExpressionNode<'i, T>>,
839        span: pest::Span<'i>,
840    ) -> Result<T, E> {
841        // The stack should be short, so let's simply do linear search.
842        if self.states.iter().any(|s| s.id == id) {
843            return Err(E::recursive_expansion(id, span));
844        }
845        self.states.push(AliasExpandingState { id, locals });
846        // Parsed defn could be cached if needed.
847        let result = self
848            .aliases_map
849            .parser
850            .parse_definition(defn)
851            .and_then(|node| self.fold_expression(node))
852            .map(|node| T::alias_expanded(id, Box::new(node)))
853            .map_err(|e| e.within_alias_expansion(id, span));
854        self.states.pop();
855        result
856    }
857}
858
859impl<'i, T, P, E> ExpressionFolder<'i, T> for AliasExpander<'i, '_, T, P>
860where
861    T: AliasExpandableExpression<'i> + Clone,
862    P: AliasDefinitionParser<Output<'i> = T, Error = E>,
863    E: AliasExpandError,
864{
865    type Error = E;
866
867    fn fold_identifier(&mut self, name: &'i str, span: pest::Span<'i>) -> Result<T, Self::Error> {
868        if let Some(subst) = self.current_locals().get(name) {
869            let id = AliasId::Parameter(name);
870            Ok(T::alias_expanded(id, Box::new(subst.clone())))
871        } else if let Some((id, defn, _doc)) = self.aliases_map.get_symbol(name) {
872            let locals = HashMap::new(); // Don't spill out the current scope
873            self.expand_defn(id, defn, locals, span)
874        } else {
875            Ok(T::identifier(name))
876        }
877    }
878
879    fn fold_pattern(
880        &mut self,
881        pattern: Box<PatternNode<'i, T>>,
882        span: pest::Span<'i>,
883    ) -> Result<T, Self::Error> {
884        if let Some((id, param, defn, _doc)) = self.aliases_map.get_pattern(pattern.name) {
885            // Resolve argument in the current scope, and pass it in to the
886            // alias expansion scope.
887            let arg = self.fold_expression(pattern.value)?;
888            let locals = HashMap::from([(param, arg)]);
889            self.expand_defn(id, defn, locals, span)
890        } else {
891            let pattern = Box::new(fold_pattern_value(self, *pattern)?);
892            Ok(T::pattern(pattern))
893        }
894    }
895
896    fn fold_function_call(
897        &mut self,
898        function: Box<FunctionCallNode<'i, T>>,
899        span: pest::Span<'i>,
900    ) -> Result<T, Self::Error> {
901        // For better error indication, builtin functions are shadowed by name,
902        // not by (name, arity).
903        if let Some(overloads) = self.aliases_map.get_function_overloads(function.name) {
904            // TODO: add support for keyword arguments
905            function
906                .ensure_no_keyword_arguments()
907                .map_err(E::invalid_arguments)?;
908            let Some((id, params, defn, _doc)) = overloads.find_by_arity(function.arity()) else {
909                let min = overloads.min_arity();
910                let max = overloads.max_arity();
911                let err = if max - min + 1 == overloads.arities().len() {
912                    function.invalid_arguments_count(min, Some(max))
913                } else {
914                    function.invalid_arguments_count_with_arities(overloads.arities())
915                };
916                return Err(E::invalid_arguments(err));
917            };
918            // Resolve arguments in the current scope, and pass them in to the alias
919            // expansion scope.
920            let args = fold_expression_nodes(self, function.args)?;
921            let locals = params.iter().map(|s| s.as_str()).zip(args).collect();
922            self.expand_defn(id, defn, locals, span)
923        } else {
924            let function = Box::new(fold_function_call_args(self, *function)?);
925            Ok(T::function_call(function))
926        }
927    }
928}
929
930/// Expands aliases recursively.
931pub fn expand_aliases<'i, T, P>(
932    node: ExpressionNode<'i, T>,
933    aliases_map: &'i AliasesMap<P, String>,
934) -> Result<ExpressionNode<'i, T>, P::Error>
935where
936    T: AliasExpandableExpression<'i> + Clone,
937    P: AliasDefinitionParser<Output<'i> = T>,
938    P::Error: AliasExpandError,
939{
940    expand_aliases_with_locals(node, aliases_map, &HashMap::new())
941}
942
943/// Expands aliases recursively with the outermost local variables.
944///
945/// Local variables are similar to alias symbols, but are scoped. Alias symbols
946/// are globally accessible from alias expressions, but local variables aren't.
947pub fn expand_aliases_with_locals<'i, T, P>(
948    node: ExpressionNode<'i, T>,
949    aliases_map: &'i AliasesMap<P, String>,
950    locals: &HashMap<&'i str, ExpressionNode<'i, T>>,
951) -> Result<ExpressionNode<'i, T>, P::Error>
952where
953    T: AliasExpandableExpression<'i> + Clone,
954    P: AliasDefinitionParser<Output<'i> = T>,
955    P::Error: AliasExpandError,
956{
957    let mut expander = AliasExpander {
958        aliases_map,
959        locals,
960        states: Vec::new(),
961    };
962    expander.fold_expression(node)
963}
964
965/// Collects similar names from the `candidates` list.
966pub fn collect_similar<I>(name: &str, candidates: I) -> Vec<String>
967where
968    I: IntoIterator,
969    I::Item: AsRef<str>,
970{
971    candidates
972        .into_iter()
973        .filter(|cand| {
974            // The parameter is borrowed from clap f5540d26
975            strsim::jaro(name, cand.as_ref()) > 0.7
976        })
977        .map(|s| s.as_ref().to_owned())
978        .sorted_unstable()
979        .collect()
980}
981
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expect_arguments() {
        // All nodes in this test share the same empty span.
        fn no_span() -> pest::Span<'static> {
            pest::Span::new("", 0, 0).unwrap()
        }

        // Positional argument node holding `v`.
        fn node(v: u32) -> ExpressionNode<'static, u32> {
            ExpressionNode::new(v, no_span())
        }

        // Keyword argument `name=v`.
        fn kwarg(name: &'static str, v: u32) -> KeywordArgument<'static, u32> {
            KeywordArgument {
                name,
                name_span: no_span(),
                value: node(v),
            }
        }

        // Function call node with the given positional and keyword arguments.
        fn call(
            name: &'static str,
            args: impl Into<Vec<ExpressionNode<'static, u32>>>,
            keyword_args: impl Into<Vec<KeywordArgument<'static, u32>>>,
        ) -> FunctionCallNode<'static, u32> {
            FunctionCallNode {
                name,
                name_span: no_span(),
                args: args.into(),
                keyword_args: keyword_args.into(),
                args_span: no_span(),
            }
        }

        // No arguments at all
        let nullary = call("foo", [], []);
        assert!(nullary.expect_no_arguments().is_ok());
        assert!(nullary.expect_some_arguments::<0>().is_ok());
        assert!(nullary.expect_arguments::<0, 0>().is_ok());
        assert!(nullary.expect_named_arguments::<0, 0>(&[]).is_ok());

        // One positional argument
        let positional = call("foo", [node(0)], []);
        assert!(positional.expect_no_arguments().is_err());
        assert_eq!(
            positional.expect_some_arguments::<0>().unwrap(),
            (&[], [node(0)].as_slice())
        );
        assert_eq!(
            positional.expect_some_arguments::<1>().unwrap(),
            (&[node(0)], [].as_slice())
        );
        assert!(positional.expect_arguments::<0, 0>().is_err());
        assert_eq!(
            positional.expect_arguments::<0, 1>().unwrap(),
            (&[], [Some(&node(0))])
        );
        assert_eq!(
            positional.expect_arguments::<1, 1>().unwrap(),
            (&[node(0)], [None])
        );
        assert!(positional.expect_named_arguments::<0, 0>(&[]).is_err());
        assert_eq!(
            positional.expect_named_arguments::<0, 1>(&["a"]).unwrap(),
            ([], [Some(&node(0))])
        );
        assert_eq!(
            positional.expect_named_arguments::<1, 0>(&["a"]).unwrap(),
            ([&node(0)], [])
        );

        // One keyword argument
        let named = call("foo", [], [kwarg("a", 0)]);
        assert!(named.expect_no_arguments().is_err());
        assert!(named.expect_some_arguments::<1>().is_err());
        assert!(named.expect_arguments::<0, 1>().is_err());
        assert!(named.expect_arguments::<1, 0>().is_err());
        assert!(named.expect_named_arguments::<0, 0>(&[]).is_err());
        assert!(named.expect_named_arguments::<0, 1>(&[]).is_err());
        assert!(named.expect_named_arguments::<1, 0>(&[]).is_err());
        assert_eq!(
            named.expect_named_arguments::<1, 0>(&["a"]).unwrap(),
            ([&node(0)], [])
        );
        assert_eq!(
            named.expect_named_arguments::<1, 1>(&["a", "b"]).unwrap(),
            ([&node(0)], [None])
        );
        assert!(named.expect_named_arguments::<1, 1>(&["b", "a"]).is_err());

        // One positional argument followed by two keyword arguments
        let mixed = call("foo", [node(0)], [kwarg("a", 1), kwarg("b", 2)]);
        assert!(mixed.expect_named_arguments::<0, 0>(&[]).is_err());
        assert!(mixed.expect_named_arguments::<1, 1>(&["a", "b"]).is_err());
        assert_eq!(
            mixed
                .expect_named_arguments::<1, 2>(&["c", "a", "b"])
                .unwrap(),
            ([&node(0)], [Some(&node(1)), Some(&node(2))])
        );
        assert_eq!(
            mixed
                .expect_named_arguments::<2, 1>(&["c", "b", "a"])
                .unwrap(),
            ([&node(0), &node(2)], [Some(&node(1))])
        );
        assert_eq!(
            mixed
                .expect_named_arguments::<0, 3>(&["c", "b", "a"])
                .unwrap(),
            ([], [Some(&node(0)), Some(&node(2)), Some(&node(1))])
        );

        // The same keyword specified twice
        let duplicated = call("foo", [], [kwarg("a", 0), kwarg("a", 1)]);
        assert!(
            duplicated
                .expect_named_arguments::<1, 1>(&["", "a"])
                .is_err()
        );
    }
}