Skip to main content

tree_sitter_cli/
test.rs

1use std::{
2    collections::BTreeMap,
3    ffi::OsStr,
4    fmt::Display as _,
5    fs,
6    io::{self, Write},
7    path::{Path, PathBuf},
8    str,
9    sync::LazyLock,
10    time::Duration,
11};
12
13use anstyle::AnsiColor;
14use anyhow::{anyhow, Context, Result};
15use clap::ValueEnum;
16use indoc::indoc;
17use regex::{
18    bytes::{Regex as ByteRegex, RegexBuilder as ByteRegexBuilder},
19    Regex,
20};
21use schemars::{JsonSchema, Schema, SchemaGenerator};
22use serde::Serialize;
23use similar::{ChangeTag, TextDiff};
24use tree_sitter::{format_sexp, Language, LogType, Parser, Query, Tree};
25use walkdir::WalkDir;
26
27use super::util;
28use crate::{
29    logger::paint,
30    parse::{
31        render_cst, ParseDebugType, ParseFileOptions, ParseOutput, ParseStats, ParseTheme, Stats,
32    },
33};
34
/// Matches a test header block in raw bytes.
///
/// A header consists of:
/// 1. an opening line of three or more `=` characters, optionally followed by a
///    suffix that does not start with `=` (`equals`, `suffix1`),
/// 2. one or more lines that either do not start with `=` or start with
///    whitespace followed by `:` (`test_name_and_markers`),
/// 3. a closing line of three or more `=` characters with its own optional
///    suffix (`suffix2`).
///
/// The pattern is written in verbose (`x`) mode, so the whitespace inside the
/// literal is not significant. Multi-line mode lets `^` match at the start of
/// any line, and both `\n` and `\r\n` line endings are accepted.
static HEADER_REGEX: LazyLock<ByteRegex> = LazyLock::new(|| {
    ByteRegexBuilder::new(
        r"^(?x)
           (?P<equals>(?:=+){3,})
           (?P<suffix1>[^=\r\n][^\r\n]*)?
           \r?\n
           (?P<test_name_and_markers>(?:([^=\r\n]|\s+:)[^\r\n]*\r?\n)+)
           ===+
           (?P<suffix2>[^=\r\n][^\r\n]*)?\r?\n",
    )
    .multi_line(true)
    .build()
    .unwrap()
});
49
/// Matches a divider line in raw bytes: three or more `-` characters at the
/// start of a line (`hyphens`), optionally followed by a suffix that does not
/// start with `-` (`suffix`), up to and including the line ending (`\n` or `\r\n`).
static DIVIDER_REGEX: LazyLock<ByteRegex> = LazyLock::new(|| {
    ByteRegexBuilder::new(r"^(?P<hyphens>(?:-+){3,})(?P<suffix>[^-\r\n][^\r\n]*)?\r?\n")
        .multi_line(true)
        .build()
        .unwrap()
});
56
/// Matches, in multi-line mode, optional leading whitespace followed by `;` and
/// the rest of that line.
static COMMENT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?m)^\s*;.*$").unwrap());
58
/// Matches a run of one or more whitespace characters.
static WHITESPACE_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\s+").unwrap());
60
/// Matches a space, a word, `: ` and an opening parenthesis, e.g. the ` name: (`
/// part of `(call name: (identifier))`.
static SEXP_FIELD_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r" \w+: \(").unwrap());
62
/// Matches a bracketed pair of decimal integers such as `[3, 14]`, together with
/// any whitespace around and inside the brackets.
static POINT_REGEX: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"\s*\[\s*\d+\s*,\s*\d+\s*\]\s*").unwrap());
65
/// A node of the parsed test tree: either a named group of nested entries or a
/// single test example.
#[derive(Debug, PartialEq, Eq)]
pub enum TestEntry {
    /// A named collection of child entries.
    Group {
        name: String,
        children: Vec<Self>,
        // NOTE(review): presumably the file or directory this group was loaded
        // from — confirm against the code that builds the tree.
        file_path: Option<PathBuf>,
    },
    /// A single test case.
    Example {
        name: String,
        /// Raw bytes handed to the parser.
        input: Vec<u8>,
        /// Expected output that the actual parse result is compared against.
        output: String,
        // The two delimiter lengths are carried over unchanged into any
        // correction recorded for this test. NOTE(review): presumably the
        // lengths of the `=` and `-` delimiter runs — confirm in the parser.
        header_delim_len: usize,
        divider_delim_len: usize,
        /// When `true`, field names are kept in the actual output before it is
        /// compared with `output`.
        has_fields: bool,
        // NOTE(review): presumably the raw attribute text from the test header;
        // it is passed through verbatim to recorded corrections.
        attributes_str: String,
        /// Parsed attributes controlling how this test is run.
        attributes: TestAttributes,
        // NOTE(review): not read in the visible part of this file — confirm usage.
        file_name: Option<String>,
    },
}
85
/// Per-test switches that control how a test example is run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TestAttributes {
    /// The test is recorded as skipped and never parsed.
    pub skip: bool,
    /// When `false`, the test is recorded with the `Platform` outcome and never parsed.
    pub platform: bool,
    /// Stop running further tests after this one (checked after an `error` test has run).
    pub fail_fast: bool,
    /// The test passes only if the resulting tree contains an error.
    pub error: bool,
    /// Compare against a rendered CST instead of an S-expression.
    pub cst: bool,
    /// Names of the languages to parse the input with. An empty name leaves the
    /// parser's current language untouched.
    pub languages: Vec<Box<str>>,
}
95
96impl Default for TestEntry {
97    fn default() -> Self {
98        Self::Group {
99            name: String::new(),
100            children: Vec::new(),
101            file_path: None,
102        }
103    }
104}
105
106impl Default for TestAttributes {
107    fn default() -> Self {
108        Self {
109            skip: false,
110            platform: true,
111            fail_fast: false,
112            error: false,
113            cst: false,
114            languages: vec!["".into()],
115        }
116    }
117}
118
// Controls how much per-test parse-rate detail the summary prints.
// Plain `//` comments are used on purpose: doc comments on a `ValueEnum` are
// picked up by clap as help text.
#[derive(ValueEnum, Default, Debug, Copy, Clone, PartialEq, Eq, Serialize)]
pub enum TestStats {
    // Print the parse rate next to every parse test, plus slow-rate warnings.
    All,
    // Print only slow-rate warnings next to individual tests (the default).
    #[default]
    OutliersAndTotal,
    // Print no per-test statistics at all.
    TotalOnly,
}
126
/// Options controlling a test run.
pub struct TestOptions<'a> {
    /// Location the tests are loaded from.
    pub path: PathBuf,
    /// Install a parser logger that writes every log message to stderr.
    pub debug: bool,
    /// Record parser graphs to `log.html` for the duration of the run.
    pub debug_graph: bool,
    // NOTE(review): `include`, `exclude` and `file_name` are not read in the
    // visible part of this file; presumably they filter which tests run — confirm.
    pub include: Option<Regex>,
    pub exclude: Option<Regex>,
    pub file_name: Option<String>,
    /// Record corrections for the tests instead of only reporting failures.
    pub update: bool,
    /// Forwarded to the graph logger together with `debug_graph`.
    pub open_log: bool,
    /// Languages available to tests that name a language in their attributes,
    /// keyed by language name.
    pub languages: BTreeMap<&'a str, &'a Language>,
    // NOTE(review): not read in the visible part of this file — confirm usage.
    pub color: bool,
    /// Keep field names in the actual output even if the expected output has none.
    pub show_fields: bool,
    // NOTE(review): not read in the visible part of this file — confirm usage.
    pub overview_only: bool,
}
141
/// A stateful object used to collect results from running a grammar's test suite
#[derive(Debug, Default, Serialize, JsonSchema)]
pub struct TestSummary {
    // Parse test results and associated data
    // (each hierarchy is serialized as a plain array of its root entries)
    #[schemars(schema_with = "schema_as_array")]
    #[serde(serialize_with = "serialize_as_array")]
    pub parse_results: TestResultHierarchy,
    // One entry per failed (or, in update mode, updated) parse test
    pub parse_failures: Vec<TestFailure>,
    // Aggregate parse statistics, printed at the very end of the summary
    pub parse_stats: Stats,
    // When set, update mode can no longer succeed and the "N failures" /
    // "N updates" heading is suppressed
    #[schemars(skip)]
    #[serde(skip)]
    pub has_parse_errors: bool,
    // How much per-test parse-rate detail to print
    #[schemars(skip)]
    #[serde(skip)]
    pub parse_stat_display: TestStats,

    // Other test results
    #[schemars(schema_with = "schema_as_array")]
    #[serde(serialize_with = "serialize_as_array")]
    pub highlight_results: TestResultHierarchy,
    #[schemars(schema_with = "schema_as_array")]
    #[serde(serialize_with = "serialize_as_array")]
    pub tag_results: TestResultHierarchy,
    #[schemars(schema_with = "schema_as_array")]
    #[serde(serialize_with = "serialize_as_array")]
    pub query_results: TestResultHierarchy,

    // Data used during construction
    // Number given to the next recorded test; starts at 1 and is bumped after
    // every recorded case
    #[schemars(skip)]
    #[serde(skip)]
    pub test_num: usize,
    // Options passed in from the CLI which control how the summary is displayed
    #[schemars(skip)]
    #[serde(skip)]
    pub color: bool,
    #[schemars(skip)]
    #[serde(skip)]
    pub overview_only: bool,
    #[schemars(skip)]
    #[serde(skip)]
    pub update: bool,
    #[schemars(skip)]
    #[serde(skip)]
    pub json: bool,
}
187
188impl TestSummary {
189    #[must_use]
190    pub fn new(
191        color: bool,
192        stat_display: TestStats,
193        parse_update: bool,
194        overview_only: bool,
195        json_summary: bool,
196    ) -> Self {
197        Self {
198            color,
199            parse_stat_display: stat_display,
200            update: parse_update,
201            overview_only,
202            json: json_summary,
203            test_num: 1,
204            ..Default::default()
205        }
206    }
207}
208
// A tree of test results that is filled in while the test tree is traversed.
// (Plain `//` comments: doc comments here would end up in the JSON schema.)
#[derive(Debug, Default, JsonSchema)]
pub struct TestResultHierarchy {
    // Top-level results; groups hold their own children inline.
    root_group: Vec<TestResult>,
    // Path from the root to the group currently being filled: one child index
    // per nesting level. Empty while results are added directly to the root.
    traversal_idxs: Vec<usize>,
}
214
215fn serialize_as_array<S>(results: &TestResultHierarchy, serializer: S) -> Result<S::Ok, S::Error>
216where
217    S: serde::Serializer,
218{
219    results.root_group.serialize(serializer)
220}
221
222fn schema_as_array(gen: &mut SchemaGenerator) -> Schema {
223    gen.subschema_for::<Vec<TestResult>>()
224}
225
226/// Stores arbitrarily nested parent test groups and child cases. Supports creation
227/// in DFS traversal order
228impl TestResultHierarchy {
229    /// Signifies the start of a new group's traversal during construction.
230    fn push_traversal(&mut self, idx: usize) {
231        self.traversal_idxs.push(idx);
232    }
233
234    /// Signifies the end of the current group's traversal during construction.
235    /// Must be paired with a prior call to [`TestResultHierarchy::add_group`].
236    pub fn pop_traversal(&mut self) {
237        self.traversal_idxs.pop();
238    }
239
240    /// Adds a new group as a child of the current group. Caller is responsible
241    /// for calling [`TestResultHierarchy::pop_traversal`] once the group is done
242    /// being traversed.
243    pub fn add_group(&mut self, group_name: &str) {
244        let new_group_idx = self.curr_group_len();
245        self.push(TestResult {
246            name: group_name.to_string(),
247            info: TestInfo::Group {
248                children: Vec::new(),
249            },
250        });
251        self.push_traversal(new_group_idx);
252    }
253
254    /// Adds a new test example as a child of the current group.
255    /// Asserts that `test_case.info` is not [`TestInfo::Group`].
256    pub fn add_case(&mut self, test_case: TestResult) {
257        assert!(!matches!(test_case.info, TestInfo::Group { .. }));
258        self.push(test_case);
259    }
260
261    /// Adds a new `TestResult` to the current group.
262    fn push(&mut self, result: TestResult) {
263        // If there are no traversal steps, we're adding to the root
264        if self.traversal_idxs.is_empty() {
265            self.root_group.push(result);
266            return;
267        }
268
269        #[allow(clippy::manual_let_else)]
270        let mut curr_group = match self.root_group[self.traversal_idxs[0]].info {
271            TestInfo::Group { ref mut children } => children,
272            _ => unreachable!(),
273        };
274        for idx in self.traversal_idxs.iter().skip(1) {
275            curr_group = match curr_group[*idx].info {
276                TestInfo::Group { ref mut children } => children,
277                _ => unreachable!(),
278            };
279        }
280
281        curr_group.push(result);
282    }
283
284    fn curr_group_len(&self) -> usize {
285        if self.traversal_idxs.is_empty() {
286            return self.root_group.len();
287        }
288
289        #[allow(clippy::manual_let_else)]
290        let mut curr_group = match self.root_group[self.traversal_idxs[0]].info {
291            TestInfo::Group { ref children } => children,
292            _ => unreachable!(),
293        };
294        for idx in self.traversal_idxs.iter().skip(1) {
295            curr_group = match curr_group[*idx].info {
296                TestInfo::Group { ref children } => children,
297                _ => unreachable!(),
298            };
299        }
300        curr_group.len()
301    }
302
303    #[allow(clippy::iter_without_into_iter)]
304    #[must_use]
305    pub fn iter(&self) -> TestResultIterWithDepth<'_> {
306        let mut stack = Vec::with_capacity(self.root_group.len());
307        for child in self.root_group.iter().rev() {
308            stack.push((0, child));
309        }
310        TestResultIterWithDepth { stack }
311    }
312}
313
/// Depth-first iterator over a [`TestResultHierarchy`] that yields each result
/// together with its nesting depth.
pub struct TestResultIterWithDepth<'a> {
    /// Entries still to be visited as `(depth, result)`; the next one to yield
    /// is at the end.
    stack: Vec<(usize, &'a TestResult)>,
}
317
318impl<'a> Iterator for TestResultIterWithDepth<'a> {
319    type Item = (usize, &'a TestResult);
320
321    fn next(&mut self) -> Option<Self::Item> {
322        self.stack.pop().inspect(|(depth, result)| {
323            if let TestInfo::Group { children } = &result.info {
324                for child in children.iter().rev() {
325                    self.stack.push((depth + 1, child));
326                }
327            }
328        })
329    }
330}
331
// A named entry of a result hierarchy: a group or a single test result.
// (Plain `//` comments: doc comments here would end up in the JSON schema.)
#[derive(Debug, Serialize, JsonSchema)]
pub struct TestResult {
    pub name: String,
    // Flattened, so the fields of `info` appear next to `name` in the output.
    #[schemars(flatten)]
    #[serde(flatten)]
    pub info: TestInfo,
}
339
// Payload of a `TestResult`. Serialized untagged, i.e. only the fields of the
// active variant are written.
// (Plain `//` comments: doc comments here would end up in the JSON schema.)
#[derive(Debug, Serialize, JsonSchema)]
#[schemars(untagged)]
#[serde(untagged)]
pub enum TestInfo {
    // A group of nested results.
    Group {
        children: Vec<TestResult>,
    },
    // Result of a single parse test.
    ParseTest {
        outcome: TestOutcome,
        // True parse rate, adjusted parse rate
        // (`None` when the test was not parsed; only the true rate is serialized)
        #[schemars(schema_with = "parse_rate_schema")]
        #[serde(serialize_with = "serialize_parse_rates")]
        parse_rate: Option<(f64, f64)>,
        test_num: usize,
    },
    // Result of a single highlight/tag/query assertion test.
    AssertionTest {
        outcome: TestOutcome,
        test_num: usize,
    },
}
360
361fn serialize_parse_rates<S>(
362    parse_rate: &Option<(f64, f64)>,
363    serializer: S,
364) -> Result<S::Ok, S::Error>
365where
366    S: serde::Serializer,
367{
368    match parse_rate {
369        None => serializer.serialize_none(),
370        Some((first, _)) => serializer.serialize_some(first),
371    }
372}
373
374fn parse_rate_schema(gen: &mut SchemaGenerator) -> Schema {
375    gen.subschema_for::<Option<f64>>()
376}
377
// Outcome of a single test.
// (Plain `//` comments: doc comments here would end up in the JSON schema.)
#[derive(Debug, Clone, Eq, PartialEq, Serialize, JsonSchema)]
pub enum TestOutcome {
    // Parse outcomes
    Passed,
    Failed,
    Updated,
    // The test carries the `skip` attribute and was not parsed
    Skipped,
    // The test's `platform` attribute is false and it was not parsed
    Platform,

    // Highlight/Tag/Query outcomes
    AssertionPassed { assertion_count: usize },
    AssertionFailed { error: String },
}
391
392impl TestSummary {
393    fn fmt_parse_results(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394        let (count, total_adj_parse_time) = self
395            .parse_results
396            .iter()
397            .filter_map(|(_, result)| match result.info {
398                TestInfo::Group { .. } => None,
399                TestInfo::ParseTest { parse_rate, .. } => parse_rate,
400                _ => unreachable!(),
401            })
402            .fold((0usize, 0.0f64), |(count, rate_accum), (_, adj_rate)| {
403                (count + 1, rate_accum + adj_rate)
404            });
405
406        let avg = total_adj_parse_time / count as f64;
407        let std_dev = {
408            let variance = self
409                .parse_results
410                .iter()
411                .filter_map(|(_, result)| match result.info {
412                    TestInfo::Group { .. } => None,
413                    TestInfo::ParseTest { parse_rate, .. } => parse_rate,
414                    _ => unreachable!(),
415                })
416                .map(|(_, rate_i)| (rate_i - avg).powi(2))
417                .sum::<f64>()
418                / count as f64;
419            variance.sqrt()
420        };
421
422        for (depth, entry) in self.parse_results.iter() {
423            write!(f, "{}", "  ".repeat(depth + 1))?;
424            match &entry.info {
425                TestInfo::Group { .. } => writeln!(f, "{}:", entry.name)?,
426                TestInfo::ParseTest {
427                    outcome,
428                    parse_rate,
429                    test_num,
430                } => {
431                    let (color, result_char) = match outcome {
432                        TestOutcome::Passed => (AnsiColor::Green, "✓"),
433                        TestOutcome::Failed => (AnsiColor::Red, "✗"),
434                        TestOutcome::Updated => (AnsiColor::Blue, "✓"),
435                        TestOutcome::Skipped => (AnsiColor::Yellow, "⌀"),
436                        TestOutcome::Platform => (AnsiColor::Magenta, "⌀"),
437                        _ => unreachable!(),
438                    };
439                    let stat_display = match (self.parse_stat_display, parse_rate) {
440                        (TestStats::TotalOnly, _) | (_, None) => String::new(),
441                        (display, Some((true_rate, adj_rate))) => {
442                            let mut stats = if display == TestStats::All {
443                                format!(" ({true_rate:.3} bytes/ms)")
444                            } else {
445                                String::new()
446                            };
447                            // 3 standard deviations below the mean, aka the "Empirical Rule"
448                            if *adj_rate < 3.0f64.mul_add(-std_dev, avg) {
449                                stats += &paint(
450                                    self.color.then_some(AnsiColor::Yellow),
451                                    &format!(
452                                        " -- Warning: Slow parse rate ({true_rate:.3} bytes/ms)"
453                                    ),
454                                );
455                            }
456                            stats
457                        }
458                    };
459                    writeln!(
460                        f,
461                        "{test_num:>3}. {result_char} {}{stat_display}",
462                        paint(self.color.then_some(color), &entry.name),
463                    )?;
464                }
465                TestInfo::AssertionTest { .. } => unreachable!(),
466            }
467        }
468
469        // Parse failure info
470        if !self.parse_failures.is_empty() && self.update && !self.has_parse_errors {
471            writeln!(
472                f,
473                "\n{} update{}:\n",
474                self.parse_failures.len(),
475                if self.parse_failures.len() == 1 {
476                    ""
477                } else {
478                    "s"
479                }
480            )?;
481
482            for (i, TestFailure { name, .. }) in self.parse_failures.iter().enumerate() {
483                writeln!(f, "  {}. {name}", i + 1)?;
484            }
485        } else if !self.parse_failures.is_empty() && !self.overview_only {
486            if !self.has_parse_errors {
487                writeln!(
488                    f,
489                    "\n{} failure{}:",
490                    self.parse_failures.len(),
491                    if self.parse_failures.len() == 1 {
492                        ""
493                    } else {
494                        "s"
495                    }
496                )?;
497            }
498
499            if self.color {
500                DiffKey.fmt(f)?;
501            }
502            for (
503                i,
504                TestFailure {
505                    name,
506                    actual,
507                    expected,
508                    is_cst,
509                },
510            ) in self.parse_failures.iter().enumerate()
511            {
512                if expected == "NO ERROR" {
513                    writeln!(f, "\n  {}. {name}:\n", i + 1)?;
514                    writeln!(f, "  Expected an ERROR node, but got:")?;
515                    let actual = if *is_cst {
516                        actual
517                    } else {
518                        &format_sexp(actual, 2)
519                    };
520                    writeln!(
521                        f,
522                        "  {}",
523                        paint(self.color.then_some(AnsiColor::Red), actual)
524                    )?;
525                } else {
526                    writeln!(f, "\n  {}. {name}:", i + 1)?;
527                    if *is_cst {
528                        writeln!(
529                            f,
530                            "{}",
531                            TestDiff::new(actual, expected).with_color(self.color)
532                        )?;
533                    } else {
534                        writeln!(
535                            f,
536                            "{}",
537                            TestDiff::new(&format_sexp(actual, 2), &format_sexp(expected, 2))
538                                .with_color(self.color,)
539                        )?;
540                    }
541                }
542            }
543        } else {
544            writeln!(f)?;
545        }
546
547        Ok(())
548    }
549}
550
551impl std::fmt::Display for TestSummary {
552    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553        self.fmt_parse_results(f)?;
554
555        let mut render_assertion_results =
556            |name: &str, results: &TestResultHierarchy| -> std::fmt::Result {
557                writeln!(f, "{name}:")?;
558                for (depth, entry) in results.iter() {
559                    write!(f, "{}", "  ".repeat(depth + 2))?;
560                    match &entry.info {
561                        TestInfo::Group { .. } => writeln!(f, "{}", entry.name)?,
562                        TestInfo::AssertionTest { outcome, test_num } => match outcome {
563                            TestOutcome::AssertionPassed { assertion_count } => writeln!(
564                                f,
565                                "{:>3}. ✓ {} ({assertion_count} assertions)",
566                                test_num,
567                                paint(self.color.then_some(AnsiColor::Green), &entry.name)
568                            )?,
569                            TestOutcome::AssertionFailed { error } => {
570                                writeln!(
571                                    f,
572                                    "{:>3}. ✗ {}",
573                                    test_num,
574                                    paint(self.color.then_some(AnsiColor::Red), &entry.name)
575                                )?;
576                                writeln!(f, "{}  {error}", "  ".repeat(depth + 1))?;
577                            }
578                            _ => unreachable!(),
579                        },
580                        TestInfo::ParseTest { .. } => unreachable!(),
581                    }
582                }
583                Ok(())
584            };
585
586        if !self.highlight_results.root_group.is_empty() {
587            render_assertion_results("syntax highlighting", &self.highlight_results)?;
588        }
589
590        if !self.tag_results.root_group.is_empty() {
591            render_assertion_results("tags", &self.tag_results)?;
592        }
593
594        if !self.query_results.root_group.is_empty() {
595            render_assertion_results("queries", &self.query_results)?;
596        }
597
598        write!(f, "{}", self.parse_stats)?;
599
600        Ok(())
601    }
602}
603
604pub fn run_tests_at_path(
605    parser: &mut Parser,
606    opts: &TestOptions,
607    test_summary: &mut TestSummary,
608) -> Result<()> {
609    let test_entry = parse_tests(&opts.path)?;
610
611    let _log_session = if opts.debug_graph {
612        Some(util::log_graphs(parser, "log.html", opts.open_log)?)
613    } else {
614        None
615    };
616    if opts.debug {
617        parser.set_logger(Some(Box::new(|log_type, message| {
618            if log_type == LogType::Lex {
619                io::stderr().write_all(b"  ").unwrap();
620            }
621            writeln!(&mut io::stderr(), "{message}").unwrap();
622        })));
623    }
624
625    let mut corrected_entries = Vec::new();
626    run_tests(
627        parser,
628        test_entry,
629        opts,
630        test_summary,
631        &mut corrected_entries,
632        true,
633    )?;
634
635    parser.stop_printing_dot_graphs();
636
637    if test_summary.parse_failures.is_empty() || (opts.update && !test_summary.has_parse_errors) {
638        Ok(())
639    } else if opts.update && test_summary.has_parse_errors {
640        Err(anyhow!(indoc! {"
641                Some tests failed to parse with unexpected `ERROR` or `MISSING` nodes, as shown above, and cannot be updated automatically.
642                Either fix the grammar or manually update the tests if this is expected."}))
643    } else {
644        Err(anyhow!(""))
645    }
646}
647
648pub fn check_queries_at_path(language: &Language, path: &Path) -> Result<()> {
649    if path.exists() {
650        for entry in WalkDir::new(path)
651            .into_iter()
652            .filter_map(std::result::Result::ok)
653            .filter(|e| {
654                e.file_type().is_file()
655                    && e.path().extension().and_then(OsStr::to_str) == Some("scm")
656                    && !e.path().starts_with(".")
657            })
658        {
659            let filepath = entry.file_name().to_str().unwrap_or("");
660            let content = fs::read_to_string(entry.path())
661                .with_context(|| format!("Error reading query file {filepath:?}"))?;
662            Query::new(language, &content)
663                .with_context(|| format!("Error in query file {filepath:?}"))?;
664        }
665    }
666    Ok(())
667}
668
/// Legend for colored diffs. Its `Display` output explains which color marks
/// expected text and which marks unexpected text.
pub struct DiffKey;
670
671impl std::fmt::Display for DiffKey {
672    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
673        write!(
674            f,
675            "\ncorrect / {} / {}",
676            paint(Some(AnsiColor::Green), "expected"),
677            paint(Some(AnsiColor::Red), "unexpected")
678        )?;
679        Ok(())
680    }
681}
682
683impl DiffKey {
684    /// Writes [`DiffKey`] to stdout
685    pub fn print() {
686        println!("{Self}");
687    }
688}
689
/// A pending line diff between the actual and the expected output of a test.
pub struct TestDiff<'a> {
    /// Output that was actually produced.
    pub actual: &'a str,
    /// Output the test expected.
    pub expected: &'a str,
    /// Whether to render with colors rather than with `+`/`-` prefixes.
    pub color: bool,
}

impl<'a> TestDiff<'a> {
    /// Creates a diff of `actual` against `expected`. Color is enabled by default.
    #[must_use]
    pub const fn new(actual: &'a str, expected: &'a str) -> Self {
        let color = true;
        Self {
            actual,
            expected,
            color,
        }
    }

    /// Returns the same diff with colored rendering switched on or off.
    #[must_use]
    pub const fn with_color(self, color: bool) -> Self {
        Self {
            actual: self.actual,
            expected: self.expected,
            color,
        }
    }
}
712
713impl std::fmt::Display for TestDiff<'_> {
714    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
715        let diff = TextDiff::from_lines(self.actual, self.expected);
716        for diff in diff.iter_all_changes() {
717            match diff.tag() {
718                ChangeTag::Equal => {
719                    if self.color {
720                        write!(f, "{diff}")?;
721                    } else {
722                        write!(f, " {diff}")?;
723                    }
724                }
725                ChangeTag::Insert => {
726                    if self.color {
727                        write!(
728                            f,
729                            "{}",
730                            paint(Some(AnsiColor::Green), diff.as_str().unwrap())
731                        )?;
732                    } else {
733                        write!(f, "+{diff}")?;
734                    }
735                    if diff.missing_newline() {
736                        writeln!(f)?;
737                    }
738                }
739                ChangeTag::Delete => {
740                    if self.color {
741                        write!(f, "{}", paint(Some(AnsiColor::Red), diff.as_str().unwrap()))?;
742                    } else {
743                        write!(f, "-{diff}")?;
744                    }
745                    if diff.missing_newline() {
746                        writeln!(f)?;
747                    }
748                }
749            }
750        }
751
752        Ok(())
753    }
754}
755
// A parse test whose actual output did not match what was expected.
// (Plain `//` comments: doc comments here would end up in the JSON schema.)
#[derive(Debug, Serialize, JsonSchema)]
pub struct TestFailure {
    name: String,
    // Output produced by the parser
    actual: String,
    // Expected output, or the marker "NO ERROR" when an error was expected but
    // the tree had none
    expected: String,
    // Whether `actual`/`expected` are rendered CSTs rather than S-expressions
    is_cst: bool,
}
763
764impl TestFailure {
765    fn new<T, U, V>(name: T, actual: U, expected: V, is_cst: bool) -> Self
766    where
767        T: Into<String>,
768        U: Into<String>,
769        V: Into<String>,
770    {
771        Self {
772            name: name.into(),
773            actual: actual.into(),
774            expected: expected.into(),
775            is_cst,
776        }
777    }
778}
779
/// Everything recorded about one test so that a corrected version of it can be
/// produced later.
struct TestCorrection {
    name: String,
    input: String,
    output: String,
    attributes_str: String,
    header_delim_len: usize,
    divider_delim_len: usize,
}

impl TestCorrection {
    /// Builds a correction record, converting each textual argument into an
    /// owned `String`. The delimiter lengths are stored unchanged.
    fn new<T, U, V, W>(
        name: T,
        input: U,
        output: V,
        attributes_str: W,
        header_delim_len: usize,
        divider_delim_len: usize,
    ) -> Self
    where
        T: Into<String>,
        U: Into<String>,
        V: Into<String>,
        W: Into<String>,
    {
        let name: String = name.into();
        let input: String = input.into();
        let output: String = output.into();
        let attributes_str: String = attributes_str.into();
        Self {
            name,
            input,
            output,
            attributes_str,
            header_delim_len,
            divider_delim_len,
        }
    }
}
814
815/// This will return false if we want to "fail fast". It will bail and not parse any more tests.
816fn run_tests(
817    parser: &mut Parser,
818    test_entry: TestEntry,
819    opts: &TestOptions,
820    test_summary: &mut TestSummary,
821    corrected_entries: &mut Vec<TestCorrection>,
822    is_root: bool,
823) -> Result<bool> {
824    match test_entry {
825        TestEntry::Example {
826            name,
827            input,
828            output,
829            header_delim_len,
830            divider_delim_len,
831            has_fields,
832            attributes_str,
833            attributes,
834            ..
835        } => {
836            if attributes.skip {
837                test_summary.parse_results.add_case(TestResult {
838                    name: name.clone(),
839                    info: TestInfo::ParseTest {
840                        outcome: TestOutcome::Skipped,
841                        parse_rate: None,
842                        test_num: test_summary.test_num,
843                    },
844                });
845                test_summary.test_num += 1;
846                return Ok(true);
847            }
848
849            if !attributes.platform {
850                test_summary.parse_results.add_case(TestResult {
851                    name: name.clone(),
852                    info: TestInfo::ParseTest {
853                        outcome: TestOutcome::Platform,
854                        parse_rate: None,
855                        test_num: test_summary.test_num,
856                    },
857                });
858                test_summary.test_num += 1;
859                return Ok(true);
860            }
861
862            for (i, language_name) in attributes.languages.iter().enumerate() {
863                if !language_name.is_empty() {
864                    let language = opts
865                        .languages
866                        .get(language_name.as_ref())
867                        .ok_or_else(|| anyhow!("Language not found: {language_name}"))?;
868                    parser.set_language(language)?;
869                }
870                let start = std::time::Instant::now();
871                let tree = parser.parse(&input, None).unwrap();
872                let parse_rate = {
873                    let parse_time = start.elapsed();
874                    let true_parse_rate = tree.root_node().byte_range().len() as f64
875                        / (parse_time.as_nanos() as f64 / 1_000_000.0);
876                    let adj_parse_rate = adjusted_parse_rate(&tree, parse_time);
877
878                    test_summary.parse_stats.total_parses += 1;
879                    test_summary.parse_stats.total_duration += parse_time;
880                    test_summary.parse_stats.total_bytes += tree.root_node().byte_range().len();
881
882                    Some((true_parse_rate, adj_parse_rate))
883                };
884
885                if attributes.error {
886                    if tree.root_node().has_error() {
887                        test_summary.parse_results.add_case(TestResult {
888                            name: name.clone(),
889                            info: TestInfo::ParseTest {
890                                outcome: TestOutcome::Passed,
891                                parse_rate,
892                                test_num: test_summary.test_num,
893                            },
894                        });
895                        test_summary.parse_stats.successful_parses += 1;
896                        if opts.update {
897                            let input = String::from_utf8(input.clone()).unwrap();
898                            let output = if attributes.cst {
899                                output.clone()
900                            } else {
901                                format_sexp(&output, 0)
902                            };
903                            corrected_entries.push(TestCorrection::new(
904                                &name,
905                                input,
906                                output,
907                                &attributes_str,
908                                header_delim_len,
909                                divider_delim_len,
910                            ));
911                        }
912                    } else {
913                        if opts.update {
914                            let input = String::from_utf8(input.clone()).unwrap();
915                            // Keep the original `expected` output if the actual output has no error
916                            let output = if attributes.cst {
917                                output.clone()
918                            } else {
919                                format_sexp(&output, 0)
920                            };
921                            corrected_entries.push(TestCorrection::new(
922                                &name,
923                                input,
924                                output,
925                                &attributes_str,
926                                header_delim_len,
927                                divider_delim_len,
928                            ));
929                        }
930                        test_summary.parse_results.add_case(TestResult {
931                            name: name.clone(),
932                            info: TestInfo::ParseTest {
933                                outcome: TestOutcome::Failed,
934                                parse_rate,
935                                test_num: test_summary.test_num,
936                            },
937                        });
938                        let actual = if attributes.cst {
939                            render_test_cst(&input, &tree)?
940                        } else {
941                            tree.root_node().to_sexp()
942                        };
943                        test_summary.parse_failures.push(TestFailure::new(
944                            &name,
945                            actual,
946                            "NO ERROR",
947                            attributes.cst,
948                        ));
949                    }
950
951                    if attributes.fail_fast {
952                        return Ok(false);
953                    }
954                } else {
955                    let mut actual = if attributes.cst {
956                        render_test_cst(&input, &tree)?
957                    } else {
958                        tree.root_node().to_sexp()
959                    };
960                    if !(attributes.cst || opts.show_fields || has_fields) {
961                        actual = strip_sexp_fields(&actual);
962                    }
963
964                    if actual == output {
965                        test_summary.parse_results.add_case(TestResult {
966                            name: name.clone(),
967                            info: TestInfo::ParseTest {
968                                outcome: TestOutcome::Passed,
969                                parse_rate,
970                                test_num: test_summary.test_num,
971                            },
972                        });
973                        test_summary.parse_stats.successful_parses += 1;
974                        if opts.update {
975                            let input = String::from_utf8(input.clone()).unwrap();
976                            let output = if attributes.cst {
977                                actual
978                            } else {
979                                format_sexp(&output, 0)
980                            };
981                            corrected_entries.push(TestCorrection::new(
982                                &name,
983                                input,
984                                output,
985                                &attributes_str,
986                                header_delim_len,
987                                divider_delim_len,
988                            ));
989                        }
990                    } else {
991                        if opts.update {
992                            let input = String::from_utf8(input.clone()).unwrap();
993                            let (expected_output, actual_output) = if attributes.cst {
994                                (output.clone(), actual.clone())
995                            } else {
996                                (format_sexp(&output, 0), format_sexp(&actual, 0))
997                            };
998
999                            // Only bail early before updating if the actual is not the output,
1000                            // sometimes users want to test cases that
1001                            // are intended to have errors, hence why this
1002                            // check isn't shown above
1003                            if actual.contains("ERROR") || actual.contains("MISSING") {
1004                                test_summary.has_parse_errors = true;
1005
1006                                // keep the original `expected` output if the actual output has an
1007                                // error
1008                                corrected_entries.push(TestCorrection::new(
1009                                    &name,
1010                                    input,
1011                                    expected_output,
1012                                    &attributes_str,
1013                                    header_delim_len,
1014                                    divider_delim_len,
1015                                ));
1016                            } else {
1017                                corrected_entries.push(TestCorrection::new(
1018                                    &name,
1019                                    input,
1020                                    actual_output,
1021                                    &attributes_str,
1022                                    header_delim_len,
1023                                    divider_delim_len,
1024                                ));
1025                                test_summary.parse_results.add_case(TestResult {
1026                                    name: name.clone(),
1027                                    info: TestInfo::ParseTest {
1028                                        outcome: TestOutcome::Updated,
1029                                        parse_rate,
1030                                        test_num: test_summary.test_num,
1031                                    },
1032                                });
1033                            }
1034                        } else {
1035                            test_summary.parse_results.add_case(TestResult {
1036                                name: name.clone(),
1037                                info: TestInfo::ParseTest {
1038                                    outcome: TestOutcome::Failed,
1039                                    parse_rate,
1040                                    test_num: test_summary.test_num,
1041                                },
1042                            });
1043                        }
1044                        test_summary.parse_failures.push(TestFailure::new(
1045                            &name,
1046                            actual,
1047                            &output,
1048                            attributes.cst,
1049                        ));
1050
1051                        if attributes.fail_fast {
1052                            return Ok(false);
1053                        }
1054                    }
1055                }
1056
1057                if i == attributes.languages.len() - 1 {
1058                    // reset to the first language
1059                    parser.set_language(opts.languages.values().next().unwrap())?;
1060                }
1061            }
1062            test_summary.test_num += 1;
1063        }
1064        TestEntry::Group {
1065            name,
1066            children,
1067            file_path,
1068        } => {
1069            if children.is_empty() {
1070                return Ok(true);
1071            }
1072
1073            let failure_count = test_summary.parse_failures.len();
1074            let mut ran_test_in_group = false;
1075
1076            let matches_filter = |name: &str, file_name: &Option<String>, opts: &TestOptions| {
1077                if let (Some(test_file_path), Some(filter_file_name)) = (file_name, &opts.file_name)
1078                {
1079                    if !filter_file_name.eq(test_file_path) {
1080                        return false;
1081                    }
1082                }
1083                if let Some(include) = &opts.include {
1084                    include.is_match(name)
1085                } else if let Some(exclude) = &opts.exclude {
1086                    !exclude.is_match(name)
1087                } else {
1088                    true
1089                }
1090            };
1091
1092            for child in children {
1093                if let TestEntry::Example {
1094                    ref name,
1095                    ref file_name,
1096                    ref input,
1097                    ref output,
1098                    ref attributes_str,
1099                    header_delim_len,
1100                    divider_delim_len,
1101                    ..
1102                } = child
1103                {
1104                    if !matches_filter(name, file_name, opts) {
1105                        if opts.update {
1106                            let input = String::from_utf8(input.clone()).unwrap();
1107                            let output = format_sexp(output, 0);
1108                            corrected_entries.push(TestCorrection::new(
1109                                name,
1110                                input,
1111                                output,
1112                                attributes_str,
1113                                header_delim_len,
1114                                divider_delim_len,
1115                            ));
1116                        }
1117
1118                        test_summary.test_num += 1;
1119                        continue;
1120                    }
1121                }
1122
1123                if !ran_test_in_group && !is_root {
1124                    test_summary.parse_results.add_group(&name);
1125                    ran_test_in_group = true;
1126                }
1127                if !run_tests(parser, child, opts, test_summary, corrected_entries, false)? {
1128                    // fail fast
1129                    return Ok(false);
1130                }
1131            }
1132            // Now that we're done traversing the children of the current group, pop
1133            // the index
1134            test_summary.parse_results.pop_traversal();
1135
1136            if let Some(file_path) = file_path {
1137                if opts.update && test_summary.parse_failures.len() - failure_count > 0 {
1138                    write_tests(&file_path, corrected_entries)?;
1139                }
1140                corrected_entries.clear();
1141            }
1142        }
1143    }
1144    Ok(true)
1145}
1146
1147/// Convenience wrapper to render a CST for a test entry.
1148fn render_test_cst(input: &[u8], tree: &Tree) -> Result<String> {
1149    let mut rendered_cst: Vec<u8> = Vec::new();
1150    let mut cursor = tree.walk();
1151    let opts = ParseFileOptions {
1152        edits: &[],
1153        output: ParseOutput::Cst,
1154        stats: &mut ParseStats::default(),
1155        print_time: false,
1156        timeout: 0,
1157        debug: ParseDebugType::Quiet,
1158        debug_graph: false,
1159        cancellation_flag: None,
1160        encoding: None,
1161        open_log: false,
1162        no_ranges: false,
1163        parse_theme: &ParseTheme::empty(),
1164    };
1165    render_cst(input, tree, &mut cursor, &opts, &mut rendered_cst)?;
1166    Ok(String::from_utf8_lossy(&rendered_cst).trim().to_string())
1167}
1168
1169// Parse time is interpreted in ns before converting to ms to avoid truncation issues
1170// Parse rates often have several outliers, leading to a large standard deviation. Taking
1171// the log of these rates serves to "flatten" out the distribution, yielding a more
1172// usable standard deviation for finding statistically significant slow parse rates
1173// NOTE: This is just a heuristic
1174#[must_use]
1175pub fn adjusted_parse_rate(tree: &Tree, parse_time: Duration) -> f64 {
1176    f64::ln(
1177        tree.root_node().byte_range().len() as f64 / (parse_time.as_nanos() as f64 / 1_000_000.0),
1178    )
1179}
1180
1181fn write_tests(file_path: &Path, corrected_entries: &[TestCorrection]) -> Result<()> {
1182    let mut buffer = fs::File::create(file_path)?;
1183    write_tests_to_buffer(&mut buffer, corrected_entries)
1184}
1185
1186fn write_tests_to_buffer(
1187    buffer: &mut impl Write,
1188    corrected_entries: &[TestCorrection],
1189) -> Result<()> {
1190    for (
1191        i,
1192        TestCorrection {
1193            name,
1194            input,
1195            output,
1196            attributes_str,
1197            header_delim_len,
1198            divider_delim_len,
1199        },
1200    ) in corrected_entries.iter().enumerate()
1201    {
1202        if i > 0 {
1203            writeln!(buffer)?;
1204        }
1205        writeln!(
1206            buffer,
1207            "{}\n{name}\n{}{}\n{input}\n{}\n\n{}",
1208            "=".repeat(*header_delim_len),
1209            if attributes_str.is_empty() {
1210                attributes_str.clone()
1211            } else {
1212                format!("{attributes_str}\n")
1213            },
1214            "=".repeat(*header_delim_len),
1215            "-".repeat(*divider_delim_len),
1216            output.trim()
1217        )?;
1218    }
1219    Ok(())
1220}
1221
1222pub fn parse_tests(path: &Path) -> io::Result<TestEntry> {
1223    let name = path
1224        .file_stem()
1225        .and_then(|s| s.to_str())
1226        .unwrap_or("")
1227        .to_string();
1228    if path.is_dir() {
1229        let mut children = Vec::new();
1230        for entry in fs::read_dir(path)? {
1231            let entry = entry?;
1232            let hidden = entry.file_name().to_str().unwrap_or("").starts_with('.');
1233            if !hidden {
1234                children.push(entry.path());
1235            }
1236        }
1237        children.sort_by(|a, b| {
1238            a.file_name()
1239                .unwrap_or_default()
1240                .cmp(b.file_name().unwrap_or_default())
1241        });
1242        let children = children
1243            .iter()
1244            .map(|path| parse_tests(path))
1245            .collect::<io::Result<Vec<TestEntry>>>()?;
1246        Ok(TestEntry::Group {
1247            name,
1248            children,
1249            file_path: None,
1250        })
1251    } else {
1252        let content = fs::read_to_string(path)?;
1253        Ok(parse_test_content(name, &content, Some(path.to_path_buf())))
1254    }
1255}
1256
1257#[must_use]
1258pub fn strip_sexp_fields(sexp: &str) -> String {
1259    SEXP_FIELD_REGEX.replace_all(sexp, " (").to_string()
1260}
1261
1262#[must_use]
1263pub fn strip_points(sexp: &str) -> String {
1264    POINT_REGEX.replace_all(sexp, "").to_string()
1265}
1266
1267fn parse_test_content(name: String, content: &str, file_path: Option<PathBuf>) -> TestEntry {
1268    let mut children = Vec::new();
1269    let bytes = content.as_bytes();
1270    let mut prev_name = String::new();
1271    let mut prev_attributes_str = String::new();
1272    let mut prev_header_end = 0;
1273
1274    // Find the first test header in the file, and determine if it has a
1275    // custom suffix. If so, then this suffix will be used to identify
1276    // all subsequent headers and divider lines in the file.
1277    let first_suffix = HEADER_REGEX
1278        .captures(bytes)
1279        .and_then(|c| c.name("suffix1"))
1280        .map(|m| String::from_utf8_lossy(m.as_bytes()));
1281
1282    // Find all of the `===` test headers, which contain the test names.
1283    // Ignore any matches whose suffix does not match the first header
1284    // suffix in the file.
1285    let header_matches = HEADER_REGEX.captures_iter(bytes).filter_map(|c| {
1286        let header_delim_len = c.name("equals").map_or(80, |m| m.as_bytes().len());
1287        let suffix1 = c
1288            .name("suffix1")
1289            .map(|m| String::from_utf8_lossy(m.as_bytes()));
1290        let suffix2 = c
1291            .name("suffix2")
1292            .map(|m| String::from_utf8_lossy(m.as_bytes()));
1293
1294        let (mut skip, mut platform, mut fail_fast, mut error, mut cst, mut languages) =
1295            (false, None, false, false, false, vec![]);
1296
1297        let test_name_and_markers = c
1298            .name("test_name_and_markers")
1299            .map_or("".as_bytes(), |m| m.as_bytes());
1300
1301        let mut test_name = String::new();
1302        let mut attributes_str = String::new();
1303
1304        let mut seen_marker = false;
1305
1306        let test_name_and_markers = str::from_utf8(test_name_and_markers).unwrap();
1307        for line in test_name_and_markers
1308            .split_inclusive('\n')
1309            .filter(|s| !s.is_empty())
1310        {
1311            let trimmed_line = line.trim();
1312            match trimmed_line.split('(').next().unwrap() {
1313                ":skip" => (seen_marker, skip) = (true, true),
1314                ":platform" => {
1315                    if let Some(platforms) = trimmed_line.strip_prefix(':').and_then(|s| {
1316                        s.strip_prefix("platform(")
1317                            .and_then(|s| s.strip_suffix(')'))
1318                    }) {
1319                        seen_marker = true;
1320                        platform = Some(
1321                            platform.unwrap_or(false) || platforms.trim() == std::env::consts::OS,
1322                        );
1323                    }
1324                }
1325                ":fail-fast" => (seen_marker, fail_fast) = (true, true),
1326                ":error" => (seen_marker, error) = (true, true),
1327                ":language" => {
1328                    if let Some(lang) = trimmed_line.strip_prefix(':').and_then(|s| {
1329                        s.strip_prefix("language(")
1330                            .and_then(|s| s.strip_suffix(')'))
1331                    }) {
1332                        seen_marker = true;
1333                        languages.push(lang.into());
1334                    }
1335                }
1336                ":cst" => (seen_marker, cst) = (true, true),
1337                _ if !seen_marker => {
1338                    test_name.push_str(line);
1339                }
1340                _ => {}
1341            }
1342        }
1343        attributes_str.push_str(test_name_and_markers.strip_prefix(&test_name).unwrap());
1344
1345        // prefer skip over error, both shouldn't be set
1346        if skip {
1347            error = false;
1348        }
1349
1350        // add a default language if none are specified, will defer to the first language
1351        if languages.is_empty() {
1352            languages.push("".into());
1353        }
1354
1355        if suffix1 == first_suffix && suffix2 == first_suffix {
1356            let header_range = c.get(0).unwrap().range();
1357            let test_name = if test_name.is_empty() {
1358                None
1359            } else {
1360                Some(test_name.trim_end().to_string())
1361            };
1362            let attributes_str = if attributes_str.is_empty() {
1363                None
1364            } else {
1365                Some(attributes_str.trim_end().to_string())
1366            };
1367            Some((
1368                header_delim_len,
1369                header_range,
1370                test_name,
1371                attributes_str,
1372                TestAttributes {
1373                    skip,
1374                    platform: platform.unwrap_or(true),
1375                    fail_fast,
1376                    error,
1377                    cst,
1378                    languages,
1379                },
1380            ))
1381        } else {
1382            None
1383        }
1384    });
1385
1386    let (mut prev_header_len, mut prev_attributes) = (80, TestAttributes::default());
1387    for (header_delim_len, header_range, test_name, attributes_str, attributes) in header_matches
1388        .chain(Some((
1389            80,
1390            bytes.len()..bytes.len(),
1391            None,
1392            None,
1393            TestAttributes::default(),
1394        )))
1395    {
1396        // Find the longest line of dashes following each test description. That line
1397        // separates the input from the expected output. Ignore any matches whose suffix
1398        // does not match the first suffix in the file.
1399        if prev_header_end > 0 {
1400            let divider_range = DIVIDER_REGEX
1401                .captures_iter(&bytes[prev_header_end..header_range.start])
1402                .filter_map(|m| {
1403                    let divider_delim_len = m.name("hyphens").map_or(80, |m| m.as_bytes().len());
1404                    let suffix = m
1405                        .name("suffix")
1406                        .map(|m| String::from_utf8_lossy(m.as_bytes()));
1407                    if suffix == first_suffix {
1408                        let range = m.get(0).unwrap().range();
1409                        Some((
1410                            divider_delim_len,
1411                            (prev_header_end + range.start)..(prev_header_end + range.end),
1412                        ))
1413                    } else {
1414                        None
1415                    }
1416                })
1417                .max_by_key(|(_, range)| range.len());
1418
1419            if let Some((divider_delim_len, divider_range)) = divider_range {
1420                if let Ok(output) = str::from_utf8(&bytes[divider_range.end..header_range.start]) {
1421                    let mut input = bytes[prev_header_end..divider_range.start].to_vec();
1422
1423                    // Remove trailing newline from the input.
1424                    input.pop();
1425                    if input.last() == Some(&b'\r') {
1426                        input.pop();
1427                    }
1428
1429                    let (output, has_fields) = if prev_attributes.cst {
1430                        (output.trim().to_string(), false)
1431                    } else {
1432                        // Remove all comments
1433                        let output = COMMENT_REGEX.replace_all(output, "").to_string();
1434
1435                        // Normalize the whitespace in the expected output.
1436                        let output = WHITESPACE_REGEX.replace_all(output.trim(), " ");
1437                        let output = output.replace(" )", ")");
1438
1439                        // Identify if the expected output has fields indicated. If not, then
1440                        // fields will not be checked.
1441                        let has_fields = SEXP_FIELD_REGEX.is_match(&output);
1442
1443                        (output, has_fields)
1444                    };
1445
1446                    let file_name = if let Some(ref path) = file_path {
1447                        path.file_name().map(|n| n.to_string_lossy().to_string())
1448                    } else {
1449                        None
1450                    };
1451
1452                    let t = TestEntry::Example {
1453                        name: prev_name,
1454                        input,
1455                        output,
1456                        header_delim_len: prev_header_len,
1457                        divider_delim_len,
1458                        has_fields,
1459                        attributes_str: prev_attributes_str,
1460                        attributes: prev_attributes,
1461                        file_name,
1462                    };
1463
1464                    children.push(t);
1465                }
1466            }
1467        }
1468        prev_attributes = attributes;
1469        prev_name = test_name.unwrap_or_default();
1470        prev_attributes_str = attributes_str.unwrap_or_default();
1471        prev_header_len = header_delim_len;
1472        prev_header_end = header_range.end;
1473    }
1474    TestEntry::Group {
1475        name,
1476        children,
1477        file_path,
1478    }
1479}
1480
1481#[cfg(test)]
1482mod tests {
1483    use serde_json::json;
1484
1485    use crate::tests::get_language;
1486
1487    use super::*;
1488
1489    #[test]
1490    fn test_parse_test_content_simple() {
1491        let entry = parse_test_content(
1492            "the-filename".to_string(),
1493            r"
1494===============
1495The first test
1496===============
1497
1498a b c
1499
1500---
1501
1502(a
1503    (b c))
1504
1505================
1506The second test
1507================
1508d
1509---
1510(d)
1511        "
1512            .trim(),
1513            None,
1514        );
1515
1516        assert_eq!(
1517            entry,
1518            TestEntry::Group {
1519                name: "the-filename".to_string(),
1520                children: vec![
1521                    TestEntry::Example {
1522                        name: "The first test".to_string(),
1523                        input: b"\na b c\n".to_vec(),
1524                        output: "(a (b c))".to_string(),
1525                        header_delim_len: 15,
1526                        divider_delim_len: 3,
1527                        has_fields: false,
1528                        attributes_str: String::new(),
1529                        attributes: TestAttributes::default(),
1530                        file_name: None,
1531                    },
1532                    TestEntry::Example {
1533                        name: "The second test".to_string(),
1534                        input: b"d".to_vec(),
1535                        output: "(d)".to_string(),
1536                        header_delim_len: 16,
1537                        divider_delim_len: 3,
1538                        has_fields: false,
1539                        attributes_str: String::new(),
1540                        attributes: TestAttributes::default(),
1541                        file_name: None,
1542                    },
1543                ],
1544                file_path: None,
1545            }
1546        );
1547    }
1548
1549    #[test]
1550    fn test_parse_test_content_with_dashes_in_source_code() {
1551        let entry = parse_test_content(
1552            "the-filename".to_string(),
1553            r"
1554==================
1555Code with dashes
1556==================
1557abc
1558---
1559defg
1560----
1561hijkl
1562-------
1563
1564(a (b))
1565
1566=========================
1567Code ending with dashes
1568=========================
1569abc
1570-----------
1571-------------------
1572
1573(c (d))
1574        "
1575            .trim(),
1576            None,
1577        );
1578
1579        assert_eq!(
1580            entry,
1581            TestEntry::Group {
1582                name: "the-filename".to_string(),
1583                children: vec![
1584                    TestEntry::Example {
1585                        name: "Code with dashes".to_string(),
1586                        input: b"abc\n---\ndefg\n----\nhijkl".to_vec(),
1587                        output: "(a (b))".to_string(),
1588                        header_delim_len: 18,
1589                        divider_delim_len: 7,
1590                        has_fields: false,
1591                        attributes_str: String::new(),
1592                        attributes: TestAttributes::default(),
1593                        file_name: None,
1594                    },
1595                    TestEntry::Example {
1596                        name: "Code ending with dashes".to_string(),
1597                        input: b"abc\n-----------".to_vec(),
1598                        output: "(c (d))".to_string(),
1599                        header_delim_len: 25,
1600                        divider_delim_len: 19,
1601                        has_fields: false,
1602                        attributes_str: String::new(),
1603                        attributes: TestAttributes::default(),
1604                        file_name: None,
1605                    },
1606                ],
1607                file_path: None,
1608            }
1609        );
1610    }
1611
1612    #[test]
1613    fn test_format_sexp() {
1614        assert_eq!(format_sexp("", 0), "");
1615        assert_eq!(
1616            format_sexp("(a b: (c) (d) e: (f (g (h (MISSING i)))))", 0),
1617            r"
1618(a
1619  b: (c)
1620  (d)
1621  e: (f
1622    (g
1623      (h
1624        (MISSING i)))))
1625"
1626            .trim()
1627        );
1628        assert_eq!(
1629            format_sexp("(program (ERROR (UNEXPECTED ' ')) (identifier))", 0),
1630            r"
1631(program
1632  (ERROR
1633    (UNEXPECTED ' '))
1634  (identifier))
1635"
1636            .trim()
1637        );
1638        assert_eq!(
1639            format_sexp(r#"(source_file (MISSING ")"))"#, 0),
1640            r#"
1641(source_file
1642  (MISSING ")"))
1643        "#
1644            .trim()
1645        );
1646        assert_eq!(
1647            format_sexp(
1648                r"(source_file (ERROR (UNEXPECTED 'f') (UNEXPECTED '+')))",
1649                0
1650            ),
1651            r"
1652(source_file
1653  (ERROR
1654    (UNEXPECTED 'f')
1655    (UNEXPECTED '+')))
1656"
1657            .trim()
1658        );
1659    }
1660
1661    #[test]
1662    fn test_write_tests_to_buffer() {
1663        let mut buffer = Vec::new();
1664        let corrected_entries = vec![
1665            TestCorrection::new(
1666                "title 1".to_string(),
1667                "input 1".to_string(),
1668                "output 1".to_string(),
1669                String::new(),
1670                80,
1671                80,
1672            ),
1673            TestCorrection::new(
1674                "title 2".to_string(),
1675                "input 2".to_string(),
1676                "output 2".to_string(),
1677                String::new(),
1678                80,
1679                80,
1680            ),
1681        ];
1682        write_tests_to_buffer(&mut buffer, &corrected_entries).unwrap();
1683        assert_eq!(
1684            String::from_utf8(buffer).unwrap(),
1685            r"
1686================================================================================
1687title 1
1688================================================================================
1689input 1
1690--------------------------------------------------------------------------------
1691
1692output 1
1693
1694================================================================================
1695title 2
1696================================================================================
1697input 2
1698--------------------------------------------------------------------------------
1699
1700output 2
1701"
1702            .trim_start()
1703            .to_string()
1704        );
1705    }
1706
1707    #[test]
1708    fn test_parse_test_content_with_comments_in_sexp() {
1709        let entry = parse_test_content(
1710            "the-filename".to_string(),
1711            r#"
1712==================
1713sexp with comment
1714==================
1715code
1716---
1717
1718; Line start comment
1719(a (b))
1720
1721==================
1722sexp with comment between
1723==================
1724code
1725---
1726
1727; Line start comment
1728(a
1729; ignore this
1730    (b)
1731    ; also ignore this
1732)
1733
1734=========================
1735sexp with ';'
1736=========================
1737code
1738---
1739
1740(MISSING ";")
1741        "#
1742            .trim(),
1743            None,
1744        );
1745
1746        assert_eq!(
1747            entry,
1748            TestEntry::Group {
1749                name: "the-filename".to_string(),
1750                children: vec![
1751                    TestEntry::Example {
1752                        name: "sexp with comment".to_string(),
1753                        input: b"code".to_vec(),
1754                        output: "(a (b))".to_string(),
1755                        header_delim_len: 18,
1756                        divider_delim_len: 3,
1757                        has_fields: false,
1758                        attributes_str: String::new(),
1759                        attributes: TestAttributes::default(),
1760                        file_name: None,
1761                    },
1762                    TestEntry::Example {
1763                        name: "sexp with comment between".to_string(),
1764                        input: b"code".to_vec(),
1765                        output: "(a (b))".to_string(),
1766                        header_delim_len: 18,
1767                        divider_delim_len: 3,
1768                        has_fields: false,
1769                        attributes_str: String::new(),
1770                        attributes: TestAttributes::default(),
1771                        file_name: None,
1772                    },
1773                    TestEntry::Example {
1774                        name: "sexp with ';'".to_string(),
1775                        input: b"code".to_vec(),
1776                        output: "(MISSING \";\")".to_string(),
1777                        header_delim_len: 25,
1778                        divider_delim_len: 3,
1779                        has_fields: false,
1780                        attributes_str: String::new(),
1781                        attributes: TestAttributes::default(),
1782                        file_name: None,
1783                    }
1784                ],
1785                file_path: None,
1786            }
1787        );
1788    }
1789
1790    #[test]
1791    fn test_parse_test_content_with_suffixes() {
1792        let entry = parse_test_content(
1793            "the-filename".to_string(),
1794            r"
1795==================asdf\()[]|{}*+?^$.-
1796First test
1797==================asdf\()[]|{}*+?^$.-
1798
1799=========================
1800NOT A TEST HEADER
1801=========================
1802-------------------------
1803
1804---asdf\()[]|{}*+?^$.-
1805
1806(a)
1807
1808==================asdf\()[]|{}*+?^$.-
1809Second test
1810==================asdf\()[]|{}*+?^$.-
1811
1812=========================
1813NOT A TEST HEADER
1814=========================
1815-------------------------
1816
1817---asdf\()[]|{}*+?^$.-
1818
1819(a)
1820
1821=========================asdf\()[]|{}*+?^$.-
1822Test name with = symbol
1823=========================asdf\()[]|{}*+?^$.-
1824
1825=========================
1826NOT A TEST HEADER
1827=========================
1828-------------------------
1829
1830---asdf\()[]|{}*+?^$.-
1831
1832(a)
1833
1834==============================asdf\()[]|{}*+?^$.-
1835Test containing equals
1836==============================asdf\()[]|{}*+?^$.-
1837
1838===
1839
1840------------------------------asdf\()[]|{}*+?^$.-
1841
1842(a)
1843
1844==============================asdf\()[]|{}*+?^$.-
1845Subsequent test containing equals
1846==============================asdf\()[]|{}*+?^$.-
1847
1848===
1849
1850------------------------------asdf\()[]|{}*+?^$.-
1851
1852(a)
1853"
1854            .trim(),
1855            None,
1856        );
1857
1858        let expected_input = b"\n=========================\n\
1859            NOT A TEST HEADER\n\
1860            =========================\n\
1861            -------------------------\n"
1862            .to_vec();
1863        pretty_assertions::assert_eq!(
1864            entry,
1865            TestEntry::Group {
1866                name: "the-filename".to_string(),
1867                children: vec![
1868                    TestEntry::Example {
1869                        name: "First test".to_string(),
1870                        input: expected_input.clone(),
1871                        output: "(a)".to_string(),
1872                        header_delim_len: 18,
1873                        divider_delim_len: 3,
1874                        has_fields: false,
1875                        attributes_str: String::new(),
1876                        attributes: TestAttributes::default(),
1877                        file_name: None,
1878                    },
1879                    TestEntry::Example {
1880                        name: "Second test".to_string(),
1881                        input: expected_input.clone(),
1882                        output: "(a)".to_string(),
1883                        header_delim_len: 18,
1884                        divider_delim_len: 3,
1885                        has_fields: false,
1886                        attributes_str: String::new(),
1887                        attributes: TestAttributes::default(),
1888                        file_name: None,
1889                    },
1890                    TestEntry::Example {
1891                        name: "Test name with = symbol".to_string(),
1892                        input: expected_input,
1893                        output: "(a)".to_string(),
1894                        header_delim_len: 25,
1895                        divider_delim_len: 3,
1896                        has_fields: false,
1897                        attributes_str: String::new(),
1898                        attributes: TestAttributes::default(),
1899                        file_name: None,
1900                    },
1901                    TestEntry::Example {
1902                        name: "Test containing equals".to_string(),
1903                        input: "\n===\n".into(),
1904                        output: "(a)".into(),
1905                        header_delim_len: 30,
1906                        divider_delim_len: 30,
1907                        has_fields: false,
1908                        attributes_str: String::new(),
1909                        attributes: TestAttributes::default(),
1910                        file_name: None,
1911                    },
1912                    TestEntry::Example {
1913                        name: "Subsequent test containing equals".to_string(),
1914                        input: "\n===\n".into(),
1915                        output: "(a)".into(),
1916                        header_delim_len: 30,
1917                        divider_delim_len: 30,
1918                        has_fields: false,
1919                        attributes_str: String::new(),
1920                        attributes: TestAttributes::default(),
1921                        file_name: None,
1922                    }
1923                ],
1924                file_path: None,
1925            }
1926        );
1927    }
1928
1929    #[test]
1930    fn test_parse_test_content_with_newlines_in_test_names() {
1931        let entry = parse_test_content(
1932            "the-filename".to_string(),
1933            r"
1934===============
1935name
1936with
1937newlines
1938===============
1939a
1940---
1941(b)
1942
1943====================
1944name with === signs
1945====================
1946code with ----
1947---
1948(d)
1949",
1950            None,
1951        );
1952
1953        assert_eq!(
1954            entry,
1955            TestEntry::Group {
1956                name: "the-filename".to_string(),
1957                file_path: None,
1958                children: vec![
1959                    TestEntry::Example {
1960                        name: "name\nwith\nnewlines".to_string(),
1961                        input: b"a".to_vec(),
1962                        output: "(b)".to_string(),
1963                        header_delim_len: 15,
1964                        divider_delim_len: 3,
1965                        has_fields: false,
1966                        attributes_str: String::new(),
1967                        attributes: TestAttributes::default(),
1968                        file_name: None,
1969                    },
1970                    TestEntry::Example {
1971                        name: "name with === signs".to_string(),
1972                        input: b"code with ----".to_vec(),
1973                        output: "(d)".to_string(),
1974                        header_delim_len: 20,
1975                        divider_delim_len: 3,
1976                        has_fields: false,
1977                        attributes_str: String::new(),
1978                        attributes: TestAttributes::default(),
1979                        file_name: None,
1980                    }
1981                ]
1982            }
1983        );
1984    }
1985
1986    #[test]
1987    fn test_parse_test_with_markers() {
1988        // do one with :skip, we should not see it in the entry output
1989
1990        let entry = parse_test_content(
1991            "the-filename".to_string(),
1992            r"
1993=====================
1994Test with skip marker
1995:skip
1996=====================
1997a
1998---
1999(b)
2000",
2001            None,
2002        );
2003
2004        assert_eq!(
2005            entry,
2006            TestEntry::Group {
2007                name: "the-filename".to_string(),
2008                file_path: None,
2009                children: vec![TestEntry::Example {
2010                    name: "Test with skip marker".to_string(),
2011                    input: b"a".to_vec(),
2012                    output: "(b)".to_string(),
2013                    header_delim_len: 21,
2014                    divider_delim_len: 3,
2015                    has_fields: false,
2016                    attributes_str: ":skip".to_string(),
2017                    attributes: TestAttributes {
2018                        skip: true,
2019                        platform: true,
2020                        fail_fast: false,
2021                        error: false,
2022                        cst: false,
2023                        languages: vec!["".into()]
2024                    },
2025                    file_name: None,
2026                }]
2027            }
2028        );
2029
2030        let entry = parse_test_content(
2031            "the-filename".to_string(),
2032            &format!(
2033                r"
2034=========================
2035Test with platform marker
2036:platform({})
2037:fail-fast
2038=========================
2039a
2040---
2041(b)
2042
2043=============================
2044Test with bad platform marker
2045:platform({})
2046
2047:language(foo)
2048=============================
2049a
2050---
2051(b)
2052
2053====================
2054Test with cst marker
2055:cst
2056====================
20571
2058---
20590:0 - 1:0   source_file
20600:0 - 0:1   expression
20610:0 - 0:1     number_literal `1`
2062",
2063                std::env::consts::OS,
2064                if std::env::consts::OS == "linux" {
2065                    "macos"
2066                } else {
2067                    "linux"
2068                }
2069            ),
2070            None,
2071        );
2072
2073        assert_eq!(
2074            entry,
2075            TestEntry::Group {
2076                name: "the-filename".to_string(),
2077                file_path: None,
2078                children: vec![
2079                    TestEntry::Example {
2080                        name: "Test with platform marker".to_string(),
2081                        input: b"a".to_vec(),
2082                        output: "(b)".to_string(),
2083                        header_delim_len: 25,
2084                        divider_delim_len: 3,
2085                        has_fields: false,
2086                        attributes_str: format!(":platform({})\n:fail-fast", std::env::consts::OS),
2087                        attributes: TestAttributes {
2088                            skip: false,
2089                            platform: true,
2090                            fail_fast: true,
2091                            error: false,
2092                            cst: false,
2093                            languages: vec!["".into()]
2094                        },
2095                        file_name: None,
2096                    },
2097                    TestEntry::Example {
2098                        name: "Test with bad platform marker".to_string(),
2099                        input: b"a".to_vec(),
2100                        output: "(b)".to_string(),
2101                        header_delim_len: 29,
2102                        divider_delim_len: 3,
2103                        has_fields: false,
2104                        attributes_str: if std::env::consts::OS == "linux" {
2105                            ":platform(macos)\n\n:language(foo)".to_string()
2106                        } else {
2107                            ":platform(linux)\n\n:language(foo)".to_string()
2108                        },
2109                        attributes: TestAttributes {
2110                            skip: false,
2111                            platform: false,
2112                            fail_fast: false,
2113                            error: false,
2114                            cst: false,
2115                            languages: vec!["foo".into()]
2116                        },
2117                        file_name: None,
2118                    },
2119                    TestEntry::Example {
2120                        name: "Test with cst marker".to_string(),
2121                        input: b"1".to_vec(),
2122                        output: "0:0 - 1:0   source_file
21230:0 - 0:1   expression
21240:0 - 0:1     number_literal `1`"
2125                            .to_string(),
2126                        header_delim_len: 20,
2127                        divider_delim_len: 3,
2128                        has_fields: false,
2129                        attributes_str: ":cst".to_string(),
2130                        attributes: TestAttributes {
2131                            skip: false,
2132                            platform: true,
2133                            fail_fast: false,
2134                            error: false,
2135                            cst: true,
2136                            languages: vec!["".into()]
2137                        },
2138                        file_name: None,
2139                    }
2140                ]
2141            }
2142        );
2143    }
2144
2145    fn clear_parse_rate(result: &mut TestResult) {
2146        let test_case_info = &mut result.info;
2147        match test_case_info {
2148            TestInfo::ParseTest {
2149                ref mut parse_rate, ..
2150            } => {
2151                assert!(parse_rate.is_some());
2152                *parse_rate = None;
2153            }
2154            TestInfo::Group { .. } | TestInfo::AssertionTest { .. } => {
2155                panic!("Unexpected test result")
2156            }
2157        }
2158    }
2159
2160    #[test]
2161    fn run_tests_simple() {
2162        let mut parser = Parser::new();
2163        let language = get_language("c");
2164        parser
2165            .set_language(&language)
2166            .expect("Failed to set language");
2167        let mut languages = BTreeMap::new();
2168        languages.insert("c", &language);
2169        let opts = TestOptions {
2170            path: PathBuf::from("foo"),
2171            debug: true,
2172            debug_graph: false,
2173            include: None,
2174            exclude: None,
2175            file_name: None,
2176            update: false,
2177            open_log: false,
2178            languages,
2179            color: true,
2180            show_fields: false,
2181            overview_only: false,
2182        };
2183
2184        // NOTE: The following test cases are combined to work around a race condition
2185        // in the loader
2186        {
2187            let test_entry = TestEntry::Group {
2188                name: "foo".to_string(),
2189                file_path: None,
2190                children: vec![TestEntry::Example {
2191                    name: "C Test 1".to_string(),
2192                    input: b"1;\n".to_vec(),
2193                    output: "(translation_unit (expression_statement (number_literal)))"
2194                        .to_string(),
2195                    header_delim_len: 25,
2196                    divider_delim_len: 3,
2197                    has_fields: false,
2198                    attributes_str: String::new(),
2199                    attributes: TestAttributes::default(),
2200                    file_name: None,
2201                }],
2202            };
2203
2204            let mut test_summary = TestSummary::new(true, TestStats::All, false, false, false);
2205            let mut corrected_entries = Vec::new();
2206            run_tests(
2207                &mut parser,
2208                test_entry,
2209                &opts,
2210                &mut test_summary,
2211                &mut corrected_entries,
2212                true,
2213            )
2214            .expect("Failed to run tests");
2215
2216            // parse rates will always be different, so we need to clear out these
2217            // fields to reliably assert equality below
2218            clear_parse_rate(&mut test_summary.parse_results.root_group[0]);
2219            test_summary.parse_stats.total_duration = Duration::from_secs(0);
2220
2221            let json_results = serde_json::to_string(&test_summary).unwrap();
2222
2223            assert_eq!(
2224                json_results,
2225                json!({
2226                  "parse_results": [
2227                    {
2228                      "name": "C Test 1",
2229                      "outcome": "Passed",
2230                      "parse_rate": null,
2231                      "test_num": 1
2232                    }
2233                  ],
2234                  "parse_failures": [],
2235                  "parse_stats": {
2236                    "successful_parses": 1,
2237                    "total_parses": 1,
2238                    "total_bytes": 3,
2239                    "total_duration": {
2240                      "secs": 0,
2241                      "nanos": 0,
2242                    }
2243                  },
2244                  "highlight_results": [],
2245                  "tag_results": [],
2246                  "query_results": []
2247                })
2248                .to_string()
2249            );
2250        }
2251        {
2252            let test_entry = TestEntry::Group {
2253                name: "corpus".to_string(),
2254                file_path: None,
2255                children: vec![
2256                    TestEntry::Group {
2257                        name: "group1".to_string(),
2258                        // This test passes
2259                        children: vec![TestEntry::Example {
2260                            name: "C Test 1".to_string(),
2261                            input: b"1;\n".to_vec(),
2262                            output: "(translation_unit (expression_statement (number_literal)))"
2263                                .to_string(),
2264                            header_delim_len: 25,
2265                            divider_delim_len: 3,
2266                            has_fields: false,
2267                            attributes_str: String::new(),
2268                            attributes: TestAttributes::default(),
2269                            file_name: None,
2270                        }],
2271                        file_path: None,
2272                    },
2273                    TestEntry::Group {
2274                        name: "group2".to_string(),
2275                        children: vec![
2276                            // This test passes
2277                            TestEntry::Example {
2278                                name: "C Test 2".to_string(),
2279                                input: b"1;\n".to_vec(),
2280                                output:
2281                                    "(translation_unit (expression_statement (number_literal)))"
2282                                        .to_string(),
2283                                header_delim_len: 25,
2284                                divider_delim_len: 3,
2285                                has_fields: false,
2286                                attributes_str: String::new(),
2287                                attributes: TestAttributes::default(),
2288                                file_name: None,
2289                            },
2290                            // This test fails, and is marked with fail-fast
2291                            TestEntry::Example {
2292                                name: "C Test 3".to_string(),
2293                                input: b"1;\n".to_vec(),
2294                                output:
2295                                    "(translation_unit (expression_statement (string_literal)))"
2296                                        .to_string(),
2297                                header_delim_len: 25,
2298                                divider_delim_len: 3,
2299                                has_fields: false,
2300                                attributes_str: String::new(),
2301                                attributes: TestAttributes {
2302                                    fail_fast: true,
2303                                    ..Default::default()
2304                                },
2305                                file_name: None,
2306                            },
2307                        ],
2308                        file_path: None,
2309                    },
2310                    // This group never runs because of the previous failure
2311                    TestEntry::Group {
2312                        name: "group3".to_string(),
2313                        // This test fails, and is marked with fail-fast
2314                        children: vec![TestEntry::Example {
2315                            name: "C Test 4".to_string(),
2316                            input: b"1;\n".to_vec(),
2317                            output: "(translation_unit (expression_statement (number_literal)))"
2318                                .to_string(),
2319                            header_delim_len: 25,
2320                            divider_delim_len: 3,
2321                            has_fields: false,
2322                            attributes_str: String::new(),
2323                            attributes: TestAttributes::default(),
2324                            file_name: None,
2325                        }],
2326                        file_path: None,
2327                    },
2328                ],
2329            };
2330
2331            let mut test_summary = TestSummary::new(true, TestStats::All, false, false, false);
2332            let mut corrected_entries = Vec::new();
2333            run_tests(
2334                &mut parser,
2335                test_entry,
2336                &opts,
2337                &mut test_summary,
2338                &mut corrected_entries,
2339                true,
2340            )
2341            .expect("Failed to run tests");
2342
2343            // parse rates will always be different, so we need to clear out these
2344            // fields to reliably assert equality below
2345            {
2346                let test_group_1_info = &mut test_summary.parse_results.root_group[0].info;
2347                match test_group_1_info {
2348                    TestInfo::Group {
2349                        ref mut children, ..
2350                    } => clear_parse_rate(&mut children[0]),
2351                    TestInfo::ParseTest { .. } | TestInfo::AssertionTest { .. } => {
2352                        panic!("Unexpected test result");
2353                    }
2354                }
2355                let test_group_2_info = &mut test_summary.parse_results.root_group[1].info;
2356                match test_group_2_info {
2357                    TestInfo::Group {
2358                        ref mut children, ..
2359                    } => {
2360                        clear_parse_rate(&mut children[0]);
2361                        clear_parse_rate(&mut children[1]);
2362                    }
2363                    TestInfo::ParseTest { .. } | TestInfo::AssertionTest { .. } => {
2364                        panic!("Unexpected test result");
2365                    }
2366                }
2367                test_summary.parse_stats.total_duration = Duration::from_secs(0);
2368            }
2369
2370            let json_results = serde_json::to_string(&test_summary).unwrap();
2371
2372            assert_eq!(
2373                json_results,
2374                json!({
2375                  "parse_results": [
2376                    {
2377                      "name": "group1",
2378                      "children": [
2379                        {
2380                          "name": "C Test 1",
2381                          "outcome": "Passed",
2382                          "parse_rate": null,
2383                          "test_num": 1
2384                        }
2385                      ]
2386                    },
2387                    {
2388                      "name": "group2",
2389                      "children": [
2390                        {
2391                          "name": "C Test 2",
2392                          "outcome": "Passed",
2393                          "parse_rate": null,
2394                          "test_num": 2
2395                        },
2396                        {
2397                          "name": "C Test 3",
2398                          "outcome": "Failed",
2399                          "parse_rate": null,
2400                          "test_num": 3
2401                        }
2402                      ]
2403                    }
2404                  ],
2405                  "parse_failures": [
2406                    {
2407                      "name": "C Test 3",
2408                      "actual": "(translation_unit (expression_statement (number_literal)))",
2409                      "expected": "(translation_unit (expression_statement (string_literal)))",
2410                      "is_cst": false,
2411                    }
2412                  ],
2413                  "parse_stats": {
2414                    "successful_parses": 2,
2415                    "total_parses": 3,
2416                    "total_bytes": 9,
2417                    "total_duration": {
2418                      "secs": 0,
2419                      "nanos": 0,
2420                    }
2421                  },
2422                  "highlight_results": [],
2423                  "tag_results": [],
2424                  "query_results": []
2425                })
2426                .to_string()
2427            );
2428        }
2429    }
2430}