Skip to main content

grammar_kit/
lib.rs

1#![doc = include_str!("../README.md")]
2
3#[cfg(feature = "syn")]
4use proc_macro2::Span;
5use std::collections::HashSet;
6#[cfg(feature = "syn")]
7use syn::ext::IdentExt;
8#[cfg(feature = "syn")]
9use syn::parse::discouraged::Speculative;
10#[cfg(feature = "syn")]
11use syn::parse::ParseStream;
12#[cfg(feature = "syn")]
13use syn::Result;
14
15#[cfg(feature = "testing")]
16pub mod testing;
17
/// Generic symbol table that tracks variable definitions in nested scopes.
///
/// The stack always contains at least one (outermost) scope: both
/// [`ScopeStack::new`] and [`Default::default`] start with a single empty
/// scope, and [`ScopeStack::exit_scope`] never removes the last one.
#[derive(Clone)]
pub struct ScopeStack {
    // Innermost scope is the last element.
    scopes: Vec<HashSet<String>>,
}

impl Default for ScopeStack {
    /// Same as [`ScopeStack::new`]. A derived `Default` would produce a stack
    /// with *no* scopes, on which `define` would silently do nothing.
    fn default() -> Self {
        Self::new()
    }
}

impl ScopeStack {
    /// Creates a stack holding a single, empty outermost scope.
    pub fn new() -> Self {
        Self {
            scopes: vec![HashSet::new()],
        }
    }

    /// Pushes a new, empty innermost scope.
    pub fn enter_scope(&mut self) {
        self.scopes.push(HashSet::new());
    }

    /// Pops the innermost scope, discarding its definitions.
    ///
    /// The outermost scope is never popped; extra calls are ignored.
    pub fn exit_scope(&mut self) {
        if self.scopes.len() > 1 {
            self.scopes.pop();
        }
    }

    /// Defines `name` in the innermost scope.
    pub fn define(&mut self, name: impl Into<String>) {
        if let Some(scope) = self.scopes.last_mut() {
            scope.insert(name.into());
        }
    }

    /// Returns `true` if `name` is defined in any scope currently on the stack.
    pub fn is_defined(&self, name: &str) -> bool {
        // Search innermost-first; `any` stops at the first hit.
        self.scopes.iter().rev().any(|scope| scope.contains(name))
    }

    /// Returns all scopes, outermost first.
    pub fn scopes(&self) -> &Vec<HashSet<String>> {
        &self.scopes
    }
}
60
/// The best error recorded so far during speculative parsing.
#[cfg(all(feature = "rt", feature = "syn"))]
#[derive(Clone)]
struct ErrorState {
    /// The error to report, already prefixed with the innermost rule name (if any).
    err: syn::Error,
    /// `true` when the error's span starts somewhere other than the start of the
    /// attempt that produced it (a heuristic for "the attempt got further").
    is_deep: bool,
}

/// Holds the state for backtracking and error reporting.
/// This must be passed mutably through the parsing chain.
#[cfg(feature = "rt")]
#[derive(Clone)]
pub struct ParseContext {
    /// Set by a parser to mark the current alternative as committed.
    is_fatal: bool,
    /// Best error seen so far; see [`ParseContext::record_error`].
    #[cfg(feature = "syn")]
    best_error: Option<ErrorState>,
    /// Symbol table for the current parse.
    pub scopes: ScopeStack,
    /// Names of the rules currently being parsed, innermost last.
    rule_stack: Vec<String>,
}

#[cfg(feature = "rt")]
impl ParseContext {
    /// Creates a context with no recorded error, an empty rule stack and a
    /// fresh symbol table.
    pub fn new() -> Self {
        Self {
            is_fatal: false,
            #[cfg(feature = "syn")]
            best_error: None,
            scopes: ScopeStack::new(),
            rule_stack: Vec::new(),
        }
    }

    /// Sets or clears the fatal (committed) flag.
    pub fn set_fatal(&mut self, fatal: bool) {
        self.is_fatal = fatal;
    }

    /// Returns the current fatal (committed) flag.
    pub fn check_fatal(&self) -> bool {
        self.is_fatal
    }

    /// Pushes `name` onto the rule stack; used to annotate recorded errors.
    pub fn enter_rule(&mut self, name: &str) {
        self.rule_stack.push(name.to_string());
    }

    /// Pops the innermost rule name (no-op when the stack is empty).
    pub fn exit_rule(&mut self) {
        self.rule_stack.pop();
    }

    /// Records an error if it is "deeper" than the current best error.
    ///
    /// An error counts as *deep* when its span starts at a different position
    /// than `start_span` (the span at the start of the failed attempt). The
    /// first error recorded is always kept; after that, a new error replaces
    /// the stored one only if the new error is deep and the stored one is not.
    ///
    /// When a rule is active (see [`ParseContext::enter_rule`]), every message
    /// carried by `err` is prefixed with `Error in rule '<name>': `.
    #[cfg(feature = "syn")]
    pub fn record_error(&mut self, err: syn::Error, start_span: Span) {
        // Heuristic: Compare the error location to the start of the attempt.
        let is_deep = err.span().start() != start_span.start();

        let replace = match &self.best_error {
            None => true,
            Some(existing) => is_deep && !existing.is_deep,
        };
        if !replace {
            // Discarded errors don't need the (allocating) rule annotation.
            return;
        }

        let err = self.annotate_with_rule(err);
        self.best_error = Some(ErrorState { err, is_deep });
    }

    /// Prefixes every message in `err` with the innermost active rule name.
    ///
    /// Unlike rebuilding a single `syn::Error` from `err.span()`, this keeps all
    /// messages of a combined error, each at its own span.
    #[cfg(feature = "syn")]
    fn annotate_with_rule(&self, err: syn::Error) -> syn::Error {
        let rule_name = match self.rule_stack.last() {
            Some(name) => name,
            None => return err,
        };

        let mut annotated: Option<syn::Error> = None;
        for message in &err {
            let prefixed = syn::Error::new(
                message.span(),
                format!("Error in rule '{}': {}", rule_name, message),
            );
            match &mut annotated {
                Some(combined) => combined.combine(prefixed),
                None => annotated = Some(prefixed),
            }
        }
        // A `syn::Error` is expected to carry at least one message; the
        // fallback only avoids a panic path.
        annotated.unwrap_or(err)
    }

    /// Removes and returns the best recorded error, if any.
    #[cfg(feature = "syn")]
    pub fn take_best_error(&mut self) -> Option<syn::Error> {
        self.best_error.take().map(|s| s.err)
    }

    // --- Symbol Table Methods ---

    /// Forwards to [`ScopeStack::enter_scope`].
    pub fn enter_scope(&mut self) {
        self.scopes.enter_scope();
    }

    /// Forwards to [`ScopeStack::exit_scope`].
    pub fn exit_scope(&mut self) {
        self.scopes.exit_scope();
    }

    /// Forwards to [`ScopeStack::define`].
    pub fn define(&mut self, name: impl Into<String>) {
        self.scopes.define(name);
    }

    /// Forwards to [`ScopeStack::is_defined`].
    pub fn is_defined(&self, name: &str) -> bool {
        self.scopes.is_defined(name)
    }

    // --- Inspection Methods ---

    /// Returns all scopes of the symbol table, outermost first.
    pub fn scopes(&self) -> &Vec<HashSet<String>> {
        self.scopes.scopes()
    }

    /// Returns the active rule names, innermost last.
    pub fn rule_stack(&self) -> &Vec<String> {
        &self.rule_stack
    }
}

#[cfg(feature = "rt")]
impl Default for ParseContext {
    fn default() -> Self {
        Self::new()
    }
}
175
176/// Encapsulates a speculative parse attempt.
177/// Requires passing the ParseContext to manage error state.
178#[cfg(all(feature = "rt", feature = "syn"))]
179#[inline]
180pub fn attempt<T, F>(input: ParseStream, ctx: &mut ParseContext, parser: F) -> Result<Option<T>>
181where
182    F: FnOnce(ParseStream, &mut ParseContext) -> Result<T>,
183{
184    let was_fatal = ctx.check_fatal();
185    ctx.set_fatal(false);
186
187    // Snapshot symbol table and rule stack
188    let scopes_snapshot = ctx.scopes.clone();
189    let rule_stack_snapshot = ctx.rule_stack.clone();
190
191    let start_span = input.span();
192    let fork = input.fork();
193
194    // Pass ctx into the closure
195    let res = parser(&fork, ctx);
196
197    let is_now_fatal = ctx.check_fatal();
198
199    match res {
200        Ok(val) => {
201            input.advance_to(&fork);
202            ctx.set_fatal(was_fatal);
203            Ok(Some(val))
204        }
205        Err(e) => {
206            if is_now_fatal {
207                // Restore state
208                ctx.scopes = scopes_snapshot;
209                ctx.rule_stack = rule_stack_snapshot;
210
211                ctx.set_fatal(true);
212                Err(e)
213            } else {
214                ctx.set_fatal(was_fatal);
215                // Record error BEFORE restoring state to capture inner rule context
216                ctx.record_error(e, start_span);
217
218                // Restore state
219                ctx.scopes = scopes_snapshot;
220                ctx.rule_stack = rule_stack_snapshot;
221
222                Ok(None)
223            }
224        }
225    }
226}
227
228/// Wrapper around attempt used specifically for recovery blocks.
229#[cfg(all(feature = "rt", feature = "syn"))]
230#[inline]
231pub fn attempt_recover<T, F>(
232    input: ParseStream,
233    ctx: &mut ParseContext,
234    parser: F,
235) -> Result<Option<T>>
236where
237    F: FnOnce(ParseStream, &mut ParseContext) -> Result<T>,
238{
239    let was_fatal = ctx.check_fatal();
240    ctx.set_fatal(false);
241
242    // Snapshot symbol table and rule stack
243    let scopes_snapshot = ctx.scopes.clone();
244    let rule_stack_snapshot = ctx.rule_stack.clone();
245
246    let start_span = input.span();
247    let fork = input.fork();
248
249    let res = parser(&fork, ctx);
250
251    // Always restore fatal state, ignoring whatever happened inside.
252    ctx.set_fatal(was_fatal);
253
254    match res {
255        Ok(val) => {
256            input.advance_to(&fork);
257            Ok(Some(val))
258        }
259        Err(e) => {
260            // Record error BEFORE restoring state
261            ctx.record_error(e, start_span);
262
263            // Restore state
264            ctx.scopes = scopes_snapshot;
265            ctx.rule_stack = rule_stack_snapshot;
266
267            Ok(None)
268        }
269    }
270}
271
272// --- Stateless Helpers (No Context Needed) ---
273
274#[cfg(all(feature = "rt", feature = "syn"))]
275#[inline]
276pub fn parse_ident(input: ParseStream) -> Result<syn::Ident> {
277    input.call(syn::Ident::parse_any)
278}
279
280#[cfg(all(feature = "rt", feature = "syn"))]
281#[inline]
282pub fn parse_int<T: std::str::FromStr>(input: ParseStream) -> Result<T>
283where
284    T::Err: std::fmt::Display,
285{
286    input.parse::<syn::LitInt>()?.base10_parse()
287}
288
289#[cfg(all(feature = "rt", feature = "syn"))]
290pub fn skip_until(input: ParseStream, predicate: impl Fn(ParseStream) -> bool) -> Result<()> {
291    while !input.is_empty() && !predicate(input) {
292        if input.parse::<proc_macro2::TokenTree>().is_err() {
293            break;
294        }
295    }
296    Ok(())
297}
298
299#[cfg(all(test, feature = "rt", feature = "syn"))]
300mod tests {
301    use super::*;
302
303    #[test]
304    fn test_rule_name_in_error() {
305        let mut ctx = ParseContext::new();
306        ctx.enter_rule("test_rule");
307
308        let err = syn::Error::new(Span::call_site(), "expected something");
309        ctx.record_error(err, Span::call_site());
310
311        let final_err = ctx.take_best_error().unwrap();
312        assert_eq!(
313            final_err.to_string(),
314            "Error in rule 'test_rule': expected something"
315        );
316    }
317
318    #[test]
319    fn test_nested_rule_name_in_error() {
320        let mut ctx = ParseContext::new();
321        ctx.enter_rule("outer");
322        ctx.enter_rule("inner");
323
324        let err = syn::Error::new(Span::call_site(), "fail");
325        ctx.record_error(err, Span::call_site());
326
327        let final_err = ctx.take_best_error().unwrap();
328        assert_eq!(final_err.to_string(), "Error in rule 'inner': fail");
329    }
330
331    #[test]
332    fn test_attempt_captures_rule_context() {
333        use syn::parse::Parser;
334
335        let mut ctx = ParseContext::new();
336
337        let parser = |input: ParseStream| {
338            ctx.enter_rule("outer");
339
340            // We simulate an attempt that fails.
341            // attempt returns Result<Option<T>>.
342            // If the closure returns Err, attempt records it and returns Ok(None) (if not fatal).
343            let _ = attempt(input, &mut ctx, |_input, _ctx| {
344                Err(syn::Error::new(Span::call_site(), "parse failed"))
345            })?;
346
347            ctx.exit_rule();
348            Ok(())
349        };
350
351        // We parse an empty string. The attempt fails immediately.
352        // The outer parser returns Ok(()).
353        // But we check ctx.best_error.
354        let _ = parser.parse_str("");
355
356        let err = ctx.take_best_error().expect("Error should be recorded");
357        assert_eq!(err.to_string(), "Error in rule 'outer': parse failed");
358    }
359}