Skip to main content

grammar_kit/
testing.rs

1//! Utilities for testing parsers generated by `syn-grammar`.
2//!
3//! This module provides a fluent API for testing parsing results,
4//! asserting success/failure, and checking error messages.
5
6#[cfg(feature = "syn")]
7use std::any::Any;
8use std::fmt::{Debug, Display};
9
/// A wrapper around a parse `Result` that offers fluent assertion methods for tests.
///
/// Besides the result itself it carries an optional human-readable context and the optional
/// source text that was parsed; both are only used to enrich failure messages.
#[derive(Debug)]
pub struct TestResult<T, E> {
    // The parse outcome under test.
    inner: Result<T, E>,
    // Optional description, printed as a `Context:` line in failure messages.
    context: Option<String>,
    // Optional source text, used to render the offending line for span-carrying errors.
    source: Option<String>,
}
16
17impl<T: Debug, E: Display + Debug + 'static> TestResult<T, E> {
18    pub fn new(result: Result<T, E>) -> Self {
19        Self {
20            inner: result,
21            context: None,
22            source: None,
23        }
24    }
25
26    pub fn with_context(mut self, context: &str) -> Self {
27        self.context = Some(context.to_string());
28        self
29    }
30
31    pub fn with_source(mut self, source: &str) -> Self {
32        self.source = Some(source.to_string());
33        self
34    }
35
36    fn format_context(&self) -> String {
37        self.context
38            .as_ref()
39            .map(|c| format!("\nContext:  {}", c))
40            .unwrap_or_default()
41    }
42
43    fn format_err(&self, err: &E) -> String {
44        format_error_impl(err, self.source.as_deref())
45    }
46
47    // 1. Asserts success and returns the value.
48    pub fn assert_success(self) -> T {
49        let ctx = self.format_context();
50        match self.inner {
51            Ok(val) => val,
52            Err(ref e) => {
53                let msg = self.format_err(e);
54                panic!(
55                    "\n🔴 TEST FAILED (Expected Success, but got Error):{}\nMessage:  {}\nError Debug: {:?}\n", 
56                    ctx, msg, e
57                );
58            }
59        }
60    }
61
62    // 2. Asserts success AND checks the value directly.
63    // Returns a nice diff output if values do not match.
64    pub fn assert_success_is<Exp>(self, expected: Exp) -> T
65    where
66        T: PartialEq<Exp>,
67        Exp: Debug,
68    {
69        let ctx = self.format_context();
70        let val = self.assert_success();
71        if val != expected {
72            panic!(
73                "\n🔴 TEST FAILED (Value Mismatch):{}\nExpected: {:?}\nGot:      {:?}\n",
74                ctx, expected, val
75            );
76        }
77        val
78    }
79
80    // 3. Asserts success AND checks the value using a closure.
81    // Useful for complex assertions or when PartialEq is not implemented.
82    pub fn assert_success_with<F>(self, f: F) -> T
83    where
84        F: FnOnce(&T),
85    {
86        let val = self.assert_success();
87        f(&val);
88        val
89    }
90
91    // 4. Asserts success AND checks the Debug representation matches.
92    // Useful for syn types where PartialEq is often missing or complicated by Spans.
93    pub fn assert_success_debug(self, expected_debug: &str) -> T {
94        let ctx = self.format_context();
95        let val = self.assert_success();
96        let actual_debug = format!("{:?}", val);
97        if actual_debug != expected_debug {
98            panic!(
99                "\n🔴 TEST FAILED (Debug Mismatch):{}\nExpected: {:?}\nGot:      {:?}\n",
100                ctx, expected_debug, actual_debug
101            );
102        }
103        val
104    }
105
106    // 5. Asserts failure and returns the error.
107    pub fn assert_failure(self) -> E {
108        let ctx = self.format_context();
109        match self.inner {
110            Ok(val) => {
111                panic!(
112                    "\n🔴 TEST FAILED (Expected Failure, but got Success):{}\nParsed Value: {:?}\n",
113                    ctx, val
114                );
115            }
116            Err(e) => e,
117        }
118    }
119
120    // 6. Asserts failure AND checks if the message contains a specific text.
121    pub fn assert_failure_contains(self, expected_msg_part: &str) {
122        let ctx = self.format_context();
123        let source = self.source.clone();
124        let err = self.assert_failure();
125        let actual_msg = err.to_string();
126        if !actual_msg.contains(expected_msg_part) {
127            let formatted = format_error_impl(&err, source.as_deref());
128            panic!(
129                "\n🔴 TEST FAILED (Error Message Mismatch):{}\nExpected part: {:?}\nActual msg:    {:?}\nError Debug:   {:?}\nFormatted:\n{}\n", 
130                ctx, expected_msg_part, actual_msg, err, formatted
131            );
132        }
133    }
134
135    // 7. Asserts success AND checks if the string representation contains a specific substring.
136    pub fn assert_success_contains(self, expected_part: &str) -> T
137    where
138        T: Display,
139    {
140        let ctx = self.format_context();
141        let val = self.assert_success();
142        let val_str = val.to_string();
143        if !val_str.contains(expected_part) {
144            panic!(
145                "\n🔴 TEST FAILED (Content Mismatch):{}\nExpected to contain: {:?}\nGot:                 {:?}\n",
146                ctx, expected_part, val_str
147            );
148        }
149        val
150    }
151}
152
153pub trait Testable<T, E> {
154    fn test(self) -> TestResult<T, E>;
155}
156
157#[cfg(feature = "syn")]
158impl<T: Debug> Testable<T, syn::Error> for syn::Result<T> {
159    fn test(self) -> TestResult<T, syn::Error> {
160        TestResult::new(self)
161    }
162}
163
164fn format_error_impl<E: Display + Debug + 'static>(err: &E, source: Option<&str>) -> String {
165    #[cfg(feature = "syn")]
166    if let Some(src) = source {
167        if let Some(syn_err) = (err as &dyn Any).downcast_ref::<syn::Error>() {
168            return pretty_print_syn_error(syn_err, src);
169        }
170    }
171    format!("{}", err)
172}
173
174#[cfg(feature = "syn")]
175fn pretty_print_syn_error(err: &syn::Error, source: &str) -> String {
176    let start = err.span().start();
177    let end = err.span().end();
178
179    if start.line == 0 {
180        return err.to_string();
181    }
182
183    let line_idx = start.line - 1;
184    let lines: Vec<&str> = source.lines().collect();
185
186    if line_idx >= lines.len() {
187        return err.to_string();
188    }
189
190    let line = lines[line_idx];
191    let col = start.column;
192
193    // Calculate width of the underline
194    // If start and end are on the same line, width is end.column - start.column
195    // Else just highlight until end of line or 1 char.
196    let width = if start.line == end.line {
197        end.column.saturating_sub(col).max(1)
198    } else {
199        1
200    };
201
202    format!(
203        "{}\n  --> line {}:{}\n   |\n {} | {}\n   | {}{}",
204        err,
205        start.line,
206        col,
207        start.line,
208        line,
209        " ".repeat(col),
210        "^".repeat(width)
211    )
212}