1use std::ascii;
18use std::collections::HashMap;
19use std::fmt;
20use std::slice;
21
22use itertools::Itertools as _;
23use pest::RuleType;
24use pest::iterators::Pair;
25use pest::iterators::Pairs;
26
/// Ordered collection of diagnostic items of type `T`.
///
/// Items are kept in the order in which they were added.
#[derive(Debug)]
pub struct Diagnostics<T> {
    // Recorded items, oldest first.
    diagnostics: Vec<T>,
}

impl<T> Diagnostics<T> {
    /// Creates an empty collection.
    pub fn new() -> Self {
        let diagnostics = Vec::new();
        Self { diagnostics }
    }

    /// Returns `true` if nothing has been recorded.
    pub fn is_empty(&self) -> bool {
        self.diagnostics.is_empty()
    }

    /// Returns the number of recorded items.
    pub fn len(&self) -> usize {
        self.diagnostics.len()
    }

    /// Iterates over the recorded items in insertion order.
    pub fn iter(&self) -> slice::Iter<'_, T> {
        self.diagnostics.iter()
    }

    /// Appends `diag` to the end of the collection.
    pub fn add_warning(&mut self, diag: T) {
        self.diagnostics.push(diag);
    }

    /// Moves every item out of `diagnostics`, converts it with `f`, and
    /// appends the results to this collection in their original order.
    pub fn extend_with<U>(&mut self, diagnostics: Diagnostics<U>, mut f: impl FnMut(U) -> T) {
        for item in diagnostics.diagnostics {
            self.diagnostics.push(f(item));
        }
    }
}
72
73impl<T> Default for Diagnostics<T> {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl<'a, T> IntoIterator for &'a Diagnostics<T> {
80 type Item = &'a T;
81 type IntoIter = slice::Iter<'a, T>;
82
83 fn into_iter(self) -> Self::IntoIter {
84 self.iter()
85 }
86}
87
88#[derive(Clone, Debug, Eq, PartialEq)]
90pub struct ExpressionNode<'i, T> {
91 pub kind: T,
93 pub span: pest::Span<'i>,
95}
96
97impl<'i, T> ExpressionNode<'i, T> {
98 pub fn new(kind: T, span: pest::Span<'i>) -> Self {
100 Self { kind, span }
101 }
102}
103
104#[derive(Clone, Debug, Eq, PartialEq)]
106pub struct PatternNode<'i, T> {
107 pub name: &'i str,
109 pub name_span: pest::Span<'i>,
111 pub value: ExpressionNode<'i, T>,
113}
114
115#[derive(Clone, Debug, Eq, PartialEq)]
117pub struct FunctionCallNode<'i, T> {
118 pub name: &'i str,
120 pub name_span: pest::Span<'i>,
122 pub args: Vec<ExpressionNode<'i, T>>,
124 pub keyword_args: Vec<KeywordArgument<'i, T>>,
126 pub args_span: pest::Span<'i>,
128}
129
130#[derive(Clone, Debug, Eq, PartialEq)]
132pub struct KeywordArgument<'i, T> {
133 pub name: &'i str,
135 pub name_span: pest::Span<'i>,
137 pub value: ExpressionNode<'i, T>,
139}
140
141impl<'i, T> FunctionCallNode<'i, T> {
142 pub fn arity(&self) -> usize {
144 self.args.len() + self.keyword_args.len()
145 }
146
147 pub fn expect_no_arguments(&self) -> Result<(), InvalidArguments<'i>> {
149 let ([], []) = self.expect_arguments()?;
150 Ok(())
151 }
152
153 pub fn expect_exact_arguments<const N: usize>(
155 &self,
156 ) -> Result<&[ExpressionNode<'i, T>; N], InvalidArguments<'i>> {
157 let (args, []) = self.expect_arguments()?;
158 Ok(args)
159 }
160
161 #[expect(clippy::type_complexity)]
171 pub fn expect_some_arguments<const N: usize>(
172 &self,
173 ) -> Result<(&[ExpressionNode<'i, T>; N], &[ExpressionNode<'i, T>]), InvalidArguments<'i>> {
174 self.ensure_no_keyword_arguments()?;
175 if self.args.len() >= N {
176 let (required, rest) = self.args.split_at(N);
177 Ok((required.try_into().unwrap(), rest))
178 } else {
179 Err(self.invalid_arguments_count(N, None))
180 }
181 }
182
183 #[expect(clippy::type_complexity)]
185 pub fn expect_arguments<const N: usize, const M: usize>(
186 &self,
187 ) -> Result<
188 (
189 &[ExpressionNode<'i, T>; N],
190 [Option<&ExpressionNode<'i, T>>; M],
191 ),
192 InvalidArguments<'i>,
193 > {
194 self.ensure_no_keyword_arguments()?;
195 let count_range = N..=(N + M);
196 if count_range.contains(&self.args.len()) {
197 let (required, rest) = self.args.split_at(N);
198 let mut optional = rest.iter().map(Some).collect_vec();
199 optional.resize(M, None);
200 Ok((
201 required.try_into().unwrap(),
202 optional.try_into().ok().unwrap(),
203 ))
204 } else {
205 let (min, max) = count_range.into_inner();
206 Err(self.invalid_arguments_count(min, Some(max)))
207 }
208 }
209
210 #[expect(clippy::type_complexity)]
216 pub fn expect_named_arguments<const N: usize, const M: usize>(
217 &self,
218 names: &[&str],
219 ) -> Result<
220 (
221 [&ExpressionNode<'i, T>; N],
222 [Option<&ExpressionNode<'i, T>>; M],
223 ),
224 InvalidArguments<'i>,
225 > {
226 if self.keyword_args.is_empty() {
227 let (required, optional) = self.expect_arguments::<N, M>()?;
228 Ok((required.each_ref(), optional))
229 } else {
230 let (required, optional) = self.expect_named_arguments_vec(names, N, N + M)?;
231 Ok((
232 required.try_into().ok().unwrap(),
233 optional.try_into().ok().unwrap(),
234 ))
235 }
236 }
237
238 #[expect(clippy::type_complexity)]
239 fn expect_named_arguments_vec(
240 &self,
241 names: &[&str],
242 min: usize,
243 max: usize,
244 ) -> Result<
245 (
246 Vec<&ExpressionNode<'i, T>>,
247 Vec<Option<&ExpressionNode<'i, T>>>,
248 ),
249 InvalidArguments<'i>,
250 > {
251 assert!(names.len() <= max);
252
253 if self.args.len() > max {
254 return Err(self.invalid_arguments_count(min, Some(max)));
255 }
256 let mut extracted = Vec::with_capacity(max);
257 extracted.extend(self.args.iter().map(Some));
258 extracted.resize(max, None);
259
260 for arg in &self.keyword_args {
261 let name = arg.name;
262 let span = arg.name_span.start_pos().span(&arg.value.span.end_pos());
263 let pos = names.iter().position(|&n| n == name).ok_or_else(|| {
264 self.invalid_arguments(format!(r#"Unexpected keyword argument "{name}""#), span)
265 })?;
266 if extracted[pos].is_some() {
267 return Err(self.invalid_arguments(
268 format!(r#"Got multiple values for keyword "{name}""#),
269 span,
270 ));
271 }
272 extracted[pos] = Some(&arg.value);
273 }
274
275 let optional = extracted.split_off(min);
276 let required = extracted.into_iter().flatten().collect_vec();
277 if required.len() != min {
278 return Err(self.invalid_arguments_count(min, Some(max)));
279 }
280 Ok((required, optional))
281 }
282
283 fn ensure_no_keyword_arguments(&self) -> Result<(), InvalidArguments<'i>> {
284 if let (Some(first), Some(last)) = (self.keyword_args.first(), self.keyword_args.last()) {
285 let span = first.name_span.start_pos().span(&last.value.span.end_pos());
286 Err(self.invalid_arguments("Unexpected keyword arguments".to_owned(), span))
287 } else {
288 Ok(())
289 }
290 }
291
292 fn invalid_arguments(&self, message: String, span: pest::Span<'i>) -> InvalidArguments<'i> {
293 InvalidArguments {
294 name: self.name,
295 message,
296 span,
297 }
298 }
299
300 fn invalid_arguments_count(&self, min: usize, max: Option<usize>) -> InvalidArguments<'i> {
301 let message = match (min, max) {
302 (min, Some(max)) if min == max => format!("Expected {min} arguments"),
303 (min, Some(max)) => format!("Expected {min} to {max} arguments"),
304 (min, None) => format!("Expected at least {min} arguments"),
305 };
306 self.invalid_arguments(message, self.args_span)
307 }
308
309 fn invalid_arguments_count_with_arities(
310 &self,
311 arities: impl IntoIterator<Item = usize>,
312 ) -> InvalidArguments<'i> {
313 let message = format!("Expected {} arguments", arities.into_iter().join(", "));
314 self.invalid_arguments(message, self.args_span)
315 }
316}
317
318#[derive(Clone, Debug)]
323pub struct InvalidArguments<'i> {
324 pub name: &'i str,
326 pub message: String,
328 pub span: pest::Span<'i>,
330}
331
332pub trait FoldableExpression<'i>: Sized {
334 fn fold<F>(self, folder: &mut F, span: pest::Span<'i>) -> Result<Self, F::Error>
336 where
337 F: ExpressionFolder<'i, Self> + ?Sized;
338}
339
340pub trait ExpressionFolder<'i, T: FoldableExpression<'i>> {
342 type Error;
344
345 fn fold_expression(
348 &mut self,
349 node: ExpressionNode<'i, T>,
350 ) -> Result<ExpressionNode<'i, T>, Self::Error> {
351 let ExpressionNode { kind, span } = node;
352 let kind = kind.fold(self, span)?;
353 Ok(ExpressionNode { kind, span })
354 }
355
356 fn fold_identifier(&mut self, name: &'i str, span: pest::Span<'i>) -> Result<T, Self::Error>;
358
359 fn fold_pattern(
361 &mut self,
362 pattern: Box<PatternNode<'i, T>>,
363 span: pest::Span<'i>,
364 ) -> Result<T, Self::Error>;
365
366 fn fold_function_call(
368 &mut self,
369 function: Box<FunctionCallNode<'i, T>>,
370 span: pest::Span<'i>,
371 ) -> Result<T, Self::Error>;
372}
373
374pub fn fold_expression_nodes<'i, F, T>(
376 folder: &mut F,
377 nodes: Vec<ExpressionNode<'i, T>>,
378) -> Result<Vec<ExpressionNode<'i, T>>, F::Error>
379where
380 F: ExpressionFolder<'i, T> + ?Sized,
381 T: FoldableExpression<'i>,
382{
383 nodes
384 .into_iter()
385 .map(|node| folder.fold_expression(node))
386 .try_collect()
387}
388
389pub fn fold_pattern_value<'i, F, T>(
391 folder: &mut F,
392 pattern: PatternNode<'i, T>,
393) -> Result<PatternNode<'i, T>, F::Error>
394where
395 F: ExpressionFolder<'i, T> + ?Sized,
396 T: FoldableExpression<'i>,
397{
398 Ok(PatternNode {
399 name: pattern.name,
400 name_span: pattern.name_span,
401 value: folder.fold_expression(pattern.value)?,
402 })
403}
404
405pub fn fold_function_call_args<'i, F, T>(
407 folder: &mut F,
408 function: FunctionCallNode<'i, T>,
409) -> Result<FunctionCallNode<'i, T>, F::Error>
410where
411 F: ExpressionFolder<'i, T> + ?Sized,
412 T: FoldableExpression<'i>,
413{
414 Ok(FunctionCallNode {
415 name: function.name,
416 name_span: function.name_span,
417 args: fold_expression_nodes(folder, function.args)?,
418 keyword_args: function
419 .keyword_args
420 .into_iter()
421 .map(|arg| {
422 Ok(KeywordArgument {
423 name: arg.name,
424 name_span: arg.name_span,
425 value: folder.fold_expression(arg.value)?,
426 })
427 })
428 .try_collect()?,
429 args_span: function.args_span,
430 })
431}
432
433#[derive(Debug)]
435pub struct StringLiteralParser<R> {
436 pub content_rule: R,
438 pub escape_rule: R,
440}
441
442impl<R: RuleType> StringLiteralParser<R> {
443 pub fn parse(&self, pairs: Pairs<R>) -> String {
445 let mut result = String::new();
446 for part in pairs {
447 if part.as_rule() == self.content_rule {
448 result.push_str(part.as_str());
449 } else if part.as_rule() == self.escape_rule {
450 match &part.as_str()[1..] {
451 "\"" => result.push('"'),
452 "\\" => result.push('\\'),
453 "t" => result.push('\t'),
454 "r" => result.push('\r'),
455 "n" => result.push('\n'),
456 "0" => result.push('\0'),
457 "e" => result.push('\x1b'),
458 hex if hex.starts_with('x') => {
459 result.push(char::from(
460 u8::from_str_radix(&hex[1..], 16).expect("hex characters"),
461 ));
462 }
463 char => panic!("invalid escape: \\{char:?}"),
464 }
465 } else {
466 panic!("unexpected part of string: {part:?}");
467 }
468 }
469 result
470 }
471}
472
/// Escapes `unescaped` for use inside a double-quoted string literal.
///
/// `"`, `\`, tab, CR, LF and NUL get short backslash escapes. Other ASCII
/// control characters are escaped by [`ascii::escape_default()`]. Everything
/// else is copied as is.
pub fn escape_string(unescaped: &str) -> String {
    let mut escaped = String::with_capacity(unescaped.len());
    for c in unescaped.chars() {
        let short_form = match c {
            '"' => Some(r#"\""#),
            '\\' => Some(r#"\\"#),
            '\t' => Some(r#"\t"#),
            '\r' => Some(r#"\r"#),
            '\n' => Some(r#"\n"#),
            '\0' => Some(r#"\0"#),
            _ => None,
        };
        if let Some(sequence) = short_form {
            escaped.push_str(sequence);
        } else if c.is_ascii_control() {
            // An ASCII control character always fits in one byte.
            escaped.extend(ascii::escape_default(c as u8).map(char::from));
        } else {
            escaped.push(c);
        }
    }
    escaped
}
494
495#[derive(Debug)]
497pub struct FunctionCallParser<R> {
498 pub function_name_rule: R,
500 pub function_arguments_rule: R,
502 pub keyword_argument_rule: R,
504 pub argument_name_rule: R,
506 pub argument_value_rule: R,
508}
509
510impl<R: RuleType> FunctionCallParser<R> {
511 pub fn parse<'i, T, E: From<InvalidArguments<'i>>>(
513 &self,
514 pair: Pair<'i, R>,
515 parse_name: impl Fn(Pair<'i, R>) -> Result<&'i str, E>,
518 parse_value: impl Fn(Pair<'i, R>) -> Result<ExpressionNode<'i, T>, E>,
519 ) -> Result<FunctionCallNode<'i, T>, E> {
520 let [name_pair, args_pair] = pair.into_inner().collect_array().unwrap();
521 assert_eq!(name_pair.as_rule(), self.function_name_rule);
522 assert_eq!(args_pair.as_rule(), self.function_arguments_rule);
523 let name_span = name_pair.as_span();
524 let args_span = args_pair.as_span();
525 let function_name = parse_name(name_pair)?;
526 let mut args = Vec::new();
527 let mut keyword_args = Vec::new();
528 for pair in args_pair.into_inner() {
529 let span = pair.as_span();
530 if pair.as_rule() == self.argument_value_rule {
531 if !keyword_args.is_empty() {
532 return Err(InvalidArguments {
533 name: function_name,
534 message: "Positional argument follows keyword argument".to_owned(),
535 span,
536 }
537 .into());
538 }
539 args.push(parse_value(pair)?);
540 } else if pair.as_rule() == self.keyword_argument_rule {
541 let [name_pair, value_pair] = pair.into_inner().collect_array().unwrap();
542 assert_eq!(name_pair.as_rule(), self.argument_name_rule);
543 assert_eq!(value_pair.as_rule(), self.argument_value_rule);
544 let name_span = name_pair.as_span();
545 let arg = KeywordArgument {
546 name: parse_name(name_pair)?,
547 name_span,
548 value: parse_value(value_pair)?,
549 };
550 keyword_args.push(arg);
551 } else {
552 panic!("unexpected argument rule {pair:?}");
553 }
554 }
555 Ok(FunctionCallNode {
556 name: function_name,
557 name_span,
558 args,
559 keyword_args,
560 args_span,
561 })
562 }
563}
564
/// Alias table mapping declarations to definitions of type `V`.
///
/// `P` is the parser used to interpret alias declarations.
#[derive(Clone, Debug, Default)]
pub struct AliasesMap<P, V> {
    // Symbol name -> definition.
    symbol_aliases: HashMap<String, V>,
    // Pattern name -> (parameter name, definition).
    pattern_aliases: HashMap<String, (String, V)>,
    // Function name -> overloads as (parameter names, definition), kept
    // sorted by parameter count so they can be binary-searched by arity.
    function_aliases: HashMap<String, Vec<(Vec<String>, V)>>,
    // Parser for alias declarations.
    parser: P,
}
576
577impl<P, V> AliasesMap<P, V> {
578 pub fn new() -> Self
580 where
581 P: Default,
582 {
583 Self {
584 symbol_aliases: Default::default(),
585 pattern_aliases: Default::default(),
586 function_aliases: Default::default(),
587 parser: Default::default(),
588 }
589 }
590
591 pub fn insert(&mut self, decl: impl AsRef<str>, defn: impl Into<V>) -> Result<(), P::Error>
596 where
597 P: AliasDeclarationParser,
598 {
599 match self.parser.parse_declaration(decl.as_ref())? {
600 AliasDeclaration::Symbol(name) => {
601 self.symbol_aliases.insert(name, defn.into());
602 }
603 AliasDeclaration::Pattern(name, param) => {
604 self.pattern_aliases.insert(name, (param, defn.into()));
605 }
606 AliasDeclaration::Function(name, params) => {
607 let overloads = self.function_aliases.entry(name).or_default();
608 match overloads.binary_search_by_key(¶ms.len(), |(params, _)| params.len()) {
609 Ok(i) => overloads[i] = (params, defn.into()),
610 Err(i) => overloads.insert(i, (params, defn.into())),
611 }
612 }
613 }
614 Ok(())
615 }
616
617 pub fn symbol_names(&self) -> impl Iterator<Item = &str> {
619 self.symbol_aliases.keys().map(|n| n.as_ref())
620 }
621
622 pub fn pattern_names(&self) -> impl Iterator<Item = &str> {
624 self.pattern_aliases.keys().map(|n| n.as_ref())
625 }
626
627 pub fn function_names(&self) -> impl Iterator<Item = &str> {
629 self.function_aliases.keys().map(|n| n.as_ref())
630 }
631
632 pub fn get_symbol(&self, name: &str) -> Option<(AliasId<'_>, &V)> {
634 self.symbol_aliases
635 .get_key_value(name)
636 .map(|(name, defn)| (AliasId::Symbol(name), defn))
637 }
638
639 pub fn get_pattern(&self, name: &str) -> Option<(AliasId<'_>, &str, &V)> {
642 self.pattern_aliases
643 .get_key_value(name)
644 .map(|(name, (param, defn))| (AliasId::Pattern(name, param), param.as_ref(), defn))
645 }
646
647 pub fn get_function(&self, name: &str, arity: usize) -> Option<(AliasId<'_>, &[String], &V)> {
650 let overloads = self.get_function_overloads(name)?;
651 overloads.find_by_arity(arity)
652 }
653
654 fn get_function_overloads(&self, name: &str) -> Option<AliasFunctionOverloads<'_, V>> {
656 let (name, overloads) = self.function_aliases.get_key_value(name)?;
657 Some(AliasFunctionOverloads { name, overloads })
658 }
659}
660
661#[derive(Clone, Debug)]
662struct AliasFunctionOverloads<'a, V> {
663 name: &'a String,
664 overloads: &'a Vec<(Vec<String>, V)>,
665}
666
667impl<'a, V> AliasFunctionOverloads<'a, V> {
668 fn arities(&self) -> impl DoubleEndedIterator<Item = usize> + ExactSizeIterator {
669 self.overloads.iter().map(|(params, _)| params.len())
670 }
671
672 fn min_arity(&self) -> usize {
673 self.arities().next().unwrap()
674 }
675
676 fn max_arity(&self) -> usize {
677 self.arities().next_back().unwrap()
678 }
679
680 fn find_by_arity(&self, arity: usize) -> Option<(AliasId<'a>, &'a [String], &'a V)> {
681 let index = self
682 .overloads
683 .binary_search_by_key(&arity, |(params, _)| params.len())
684 .ok()?;
685 let (params, defn) = &self.overloads[index];
686 Some((AliasId::Function(self.name, params), params, defn))
690 }
691}
692
/// Borrowed identifier of an alias or alias parameter.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AliasId<'a> {
    /// Symbol alias, by name.
    Symbol(&'a str),
    /// Pattern alias, by name and parameter name.
    Pattern(&'a str, &'a str),
    /// Function alias, by name and parameter names.
    Function(&'a str, &'a [String]),
    /// Alias parameter, by name.
    Parameter(&'a str),
}

/// Renders `name`, `name:param`, or `name(p1, p2)` depending on the variant.
impl fmt::Display for AliasId<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Symbol(name) | Self::Parameter(name) => f.write_str(name),
            Self::Pattern(name, param) => write!(f, "{name}:{param}"),
            Self::Function(name, params) => {
                let joined = params.join(", ");
                write!(f, "{name}({joined})")
            }
        }
    }
}
718
/// Parsed declaration part of an alias rule.
#[derive(Clone, Debug)]
pub enum AliasDeclaration {
    /// Symbol alias, by name.
    Symbol(String),
    /// Pattern alias, by name and parameter name.
    Pattern(String, String),
    /// Function alias, by name and parameter names.
    Function(String, Vec<String>),
}
729
730pub trait AliasDeclarationParser {
735 type Error;
737
738 fn parse_declaration(&self, source: &str) -> Result<AliasDeclaration, Self::Error>;
740}
741
742pub trait AliasDefinitionParser {
744 type Output<'i>;
746 type Error;
748
749 fn parse_definition<'i>(
751 &self,
752 source: &'i str,
753 ) -> Result<ExpressionNode<'i, Self::Output<'i>>, Self::Error>;
754}
755
756pub trait AliasExpandableExpression<'i>: FoldableExpression<'i> {
758 fn identifier(name: &'i str) -> Self;
760 fn pattern(pattern: Box<PatternNode<'i, Self>>) -> Self;
762 fn function_call(function: Box<FunctionCallNode<'i, Self>>) -> Self;
764 fn alias_expanded(id: AliasId<'i>, subst: Box<ExpressionNode<'i, Self>>) -> Self;
766}
767
768pub trait AliasExpandError: Sized {
770 fn invalid_arguments(err: InvalidArguments<'_>) -> Self;
772 fn recursive_expansion(id: AliasId<'_>, span: pest::Span<'_>) -> Self;
774 fn within_alias_expansion(self, id: AliasId<'_>, span: pest::Span<'_>) -> Self;
776}
777
778#[derive(Debug)]
780struct AliasExpander<'i, 'a, T, P> {
781 aliases_map: &'i AliasesMap<P, String>,
783 locals: &'a HashMap<&'i str, ExpressionNode<'i, T>>,
785 states: Vec<AliasExpandingState<'i, T>>,
787}
788
789#[derive(Debug)]
790struct AliasExpandingState<'i, T> {
791 id: AliasId<'i>,
792 locals: HashMap<&'i str, ExpressionNode<'i, T>>,
793}
794
795impl<'i, T, P, E> AliasExpander<'i, '_, T, P>
796where
797 T: AliasExpandableExpression<'i> + Clone,
798 P: AliasDefinitionParser<Output<'i> = T, Error = E>,
799 E: AliasExpandError,
800{
801 fn current_locals(&self) -> &HashMap<&'i str, ExpressionNode<'i, T>> {
803 self.states.last().map_or(self.locals, |s| &s.locals)
804 }
805
806 fn expand_defn(
807 &mut self,
808 id: AliasId<'i>,
809 defn: &'i str,
810 locals: HashMap<&'i str, ExpressionNode<'i, T>>,
811 span: pest::Span<'i>,
812 ) -> Result<T, E> {
813 if self.states.iter().any(|s| s.id == id) {
815 return Err(E::recursive_expansion(id, span));
816 }
817 self.states.push(AliasExpandingState { id, locals });
818 let result = self
820 .aliases_map
821 .parser
822 .parse_definition(defn)
823 .and_then(|node| self.fold_expression(node))
824 .map(|node| T::alias_expanded(id, Box::new(node)))
825 .map_err(|e| e.within_alias_expansion(id, span));
826 self.states.pop();
827 result
828 }
829}
830
831impl<'i, T, P, E> ExpressionFolder<'i, T> for AliasExpander<'i, '_, T, P>
832where
833 T: AliasExpandableExpression<'i> + Clone,
834 P: AliasDefinitionParser<Output<'i> = T, Error = E>,
835 E: AliasExpandError,
836{
837 type Error = E;
838
839 fn fold_identifier(&mut self, name: &'i str, span: pest::Span<'i>) -> Result<T, Self::Error> {
840 if let Some(subst) = self.current_locals().get(name) {
841 let id = AliasId::Parameter(name);
842 Ok(T::alias_expanded(id, Box::new(subst.clone())))
843 } else if let Some((id, defn)) = self.aliases_map.get_symbol(name) {
844 let locals = HashMap::new(); self.expand_defn(id, defn, locals, span)
846 } else {
847 Ok(T::identifier(name))
848 }
849 }
850
851 fn fold_pattern(
852 &mut self,
853 pattern: Box<PatternNode<'i, T>>,
854 span: pest::Span<'i>,
855 ) -> Result<T, Self::Error> {
856 if let Some((id, param, defn)) = self.aliases_map.get_pattern(pattern.name) {
857 let arg = self.fold_expression(pattern.value)?;
860 let locals = HashMap::from([(param, arg)]);
861 self.expand_defn(id, defn, locals, span)
862 } else {
863 let pattern = Box::new(fold_pattern_value(self, *pattern)?);
864 Ok(T::pattern(pattern))
865 }
866 }
867
868 fn fold_function_call(
869 &mut self,
870 function: Box<FunctionCallNode<'i, T>>,
871 span: pest::Span<'i>,
872 ) -> Result<T, Self::Error> {
873 if let Some(overloads) = self.aliases_map.get_function_overloads(function.name) {
876 function
878 .ensure_no_keyword_arguments()
879 .map_err(E::invalid_arguments)?;
880 let Some((id, params, defn)) = overloads.find_by_arity(function.arity()) else {
881 let min = overloads.min_arity();
882 let max = overloads.max_arity();
883 let err = if max - min + 1 == overloads.arities().len() {
884 function.invalid_arguments_count(min, Some(max))
885 } else {
886 function.invalid_arguments_count_with_arities(overloads.arities())
887 };
888 return Err(E::invalid_arguments(err));
889 };
890 let args = fold_expression_nodes(self, function.args)?;
893 let locals = params.iter().map(|s| s.as_str()).zip(args).collect();
894 self.expand_defn(id, defn, locals, span)
895 } else {
896 let function = Box::new(fold_function_call_args(self, *function)?);
897 Ok(T::function_call(function))
898 }
899 }
900}
901
902pub fn expand_aliases<'i, T, P>(
904 node: ExpressionNode<'i, T>,
905 aliases_map: &'i AliasesMap<P, String>,
906) -> Result<ExpressionNode<'i, T>, P::Error>
907where
908 T: AliasExpandableExpression<'i> + Clone,
909 P: AliasDefinitionParser<Output<'i> = T>,
910 P::Error: AliasExpandError,
911{
912 expand_aliases_with_locals(node, aliases_map, &HashMap::new())
913}
914
915pub fn expand_aliases_with_locals<'i, T, P>(
920 node: ExpressionNode<'i, T>,
921 aliases_map: &'i AliasesMap<P, String>,
922 locals: &HashMap<&'i str, ExpressionNode<'i, T>>,
923) -> Result<ExpressionNode<'i, T>, P::Error>
924where
925 T: AliasExpandableExpression<'i> + Clone,
926 P: AliasDefinitionParser<Output<'i> = T>,
927 P::Error: AliasExpandError,
928{
929 let mut expander = AliasExpander {
930 aliases_map,
931 locals,
932 states: Vec::new(),
933 };
934 expander.fold_expression(node)
935}
936
937pub fn collect_similar<I>(name: &str, candidates: I) -> Vec<String>
939where
940 I: IntoIterator,
941 I::Item: AsRef<str>,
942{
943 candidates
944 .into_iter()
945 .filter(|cand| {
946 strsim::jaro(name, cand.as_ref()) > 0.7
948 })
949 .map(|s| s.as_ref().to_owned())
950 .sorted_unstable()
951 .collect()
952}
953
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the `expect_*arguments` family on calls with positional
    /// arguments, keyword arguments, and both.
    #[test]
    fn test_expect_arguments() {
        // Spans are irrelevant here, so every node shares an empty one.
        fn empty_span() -> pest::Span<'static> {
            pest::Span::new("", 0, 0).unwrap()
        }

        fn value(v: u32) -> ExpressionNode<'static, u32> {
            ExpressionNode::new(v, empty_span())
        }

        fn keyword(name: &'static str, v: u32) -> KeywordArgument<'static, u32> {
            KeywordArgument {
                name,
                name_span: empty_span(),
                value: value(v),
            }
        }

        fn function(
            name: &'static str,
            args: impl Into<Vec<ExpressionNode<'static, u32>>>,
            keyword_args: impl Into<Vec<KeywordArgument<'static, u32>>>,
        ) -> FunctionCallNode<'static, u32> {
            FunctionCallNode {
                name,
                name_span: empty_span(),
                args: args.into(),
                keyword_args: keyword_args.into(),
                args_span: empty_span(),
            }
        }

        // No arguments at all
        let f = function("foo", [], []);
        assert!(f.expect_no_arguments().is_ok());
        assert!(f.expect_some_arguments::<0>().is_ok());
        assert!(f.expect_arguments::<0, 0>().is_ok());
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_ok());

        // One positional argument
        let f = function("foo", [value(0)], []);
        assert!(f.expect_no_arguments().is_err());
        assert_eq!(
            f.expect_some_arguments::<0>().unwrap(),
            (&[], [value(0)].as_slice())
        );
        assert_eq!(
            f.expect_some_arguments::<1>().unwrap(),
            (&[value(0)], [].as_slice())
        );
        assert!(f.expect_arguments::<0, 0>().is_err());
        assert_eq!(
            f.expect_arguments::<0, 1>().unwrap(),
            (&[], [Some(&value(0))])
        );
        assert_eq!(f.expect_arguments::<1, 1>().unwrap(), (&[value(0)], [None]));
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_err());
        assert_eq!(
            f.expect_named_arguments::<0, 1>(&["a"]).unwrap(),
            ([], [Some(&value(0))])
        );
        assert_eq!(
            f.expect_named_arguments::<1, 0>(&["a"]).unwrap(),
            ([&value(0)], [])
        );

        // One keyword argument
        let f = function("foo", [], [keyword("a", 0)]);
        assert!(f.expect_no_arguments().is_err());
        assert!(f.expect_some_arguments::<1>().is_err());
        assert!(f.expect_arguments::<0, 1>().is_err());
        assert!(f.expect_arguments::<1, 0>().is_err());
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_err());
        assert!(f.expect_named_arguments::<0, 1>(&[]).is_err());
        assert!(f.expect_named_arguments::<1, 0>(&[]).is_err());
        assert_eq!(
            f.expect_named_arguments::<1, 0>(&["a"]).unwrap(),
            ([&value(0)], [])
        );
        assert_eq!(
            f.expect_named_arguments::<1, 1>(&["a", "b"]).unwrap(),
            ([&value(0)], [None])
        );
        // "a" lands in the optional slot, leaving the required "b" unfilled.
        assert!(f.expect_named_arguments::<1, 1>(&["b", "a"]).is_err());

        // One positional argument followed by two keyword arguments
        let f = function("foo", [value(0)], [keyword("a", 1), keyword("b", 2)]);
        assert!(f.expect_named_arguments::<0, 0>(&[]).is_err());
        // The positional argument already occupies the slot of "a".
        assert!(f.expect_named_arguments::<1, 1>(&["a", "b"]).is_err());
        assert_eq!(
            f.expect_named_arguments::<1, 2>(&["c", "a", "b"]).unwrap(),
            ([&value(0)], [Some(&value(1)), Some(&value(2))])
        );
        assert_eq!(
            f.expect_named_arguments::<2, 1>(&["c", "b", "a"]).unwrap(),
            ([&value(0), &value(2)], [Some(&value(1))])
        );
        assert_eq!(
            f.expect_named_arguments::<0, 3>(&["c", "b", "a"]).unwrap(),
            ([], [Some(&value(0)), Some(&value(2)), Some(&value(1))])
        );

        // The same keyword given twice
        let f = function("foo", [], [keyword("a", 0), keyword("a", 1)]);
        assert!(f.expect_named_arguments::<1, 1>(&["", "a"]).is_err());
    }
}