Skip to main content

grammar_kit/
lib.rs

1#![doc = include_str!("../README.md")]
2
3#[cfg(feature = "syn")]
4use proc_macro2::Span;
5use std::collections::HashSet;
6#[cfg(feature = "syn")]
7use syn::parse::discouraged::Speculative;
8#[cfg(feature = "syn")]
9use syn::parse::ParseStream;
10#[cfg(feature = "syn")]
11use syn::Result;
12
13#[cfg(feature = "testing")]
14pub mod testing;
15
/// Generic symbol table that tracks variable definitions in nested scopes.
///
/// The stack always holds at least one (root) scope. Both [`ScopeStack::new`]
/// and [`Default::default`] start with a single empty scope, and
/// [`ScopeStack::exit_scope`] never pops the last remaining one.
#[derive(Clone, Debug)]
pub struct ScopeStack {
    /// Scopes from outermost (index 0) to innermost (last element).
    scopes: Vec<HashSet<String>>,
}

impl Default for ScopeStack {
    /// Equivalent to [`ScopeStack::new`]: one empty root scope.
    ///
    /// A derived `Default` would yield an empty `Vec` with no scope at all.
    /// In that case `define` finds no innermost scope and silently drops
    /// every name.
    fn default() -> Self {
        Self::new()
    }
}

impl ScopeStack {
    /// Creates a stack containing a single, empty root scope.
    pub fn new() -> Self {
        Self {
            scopes: vec![HashSet::new()],
        }
    }

    /// Pushes a new, empty innermost scope.
    pub fn enter_scope(&mut self) {
        self.scopes.push(HashSet::new());
    }

    /// Pops the innermost scope, discarding the names defined in it.
    ///
    /// The root scope is never popped. Calling this when only the root
    /// scope remains is a no-op.
    pub fn exit_scope(&mut self) {
        if self.scopes.len() > 1 {
            self.scopes.pop();
        }
    }

    /// Defines `name` in the innermost scope.
    pub fn define(&mut self, name: impl Into<String>) {
        // Constructors and `exit_scope` keep at least one scope, so this is
        // always `Some`; the check is purely defensive.
        if let Some(scope) = self.scopes.last_mut() {
            scope.insert(name.into());
        }
    }

    /// Returns `true` if `name` is defined in any scope.
    ///
    /// Scopes are searched from innermost to outermost.
    pub fn is_defined(&self, name: &str) -> bool {
        self.scopes.iter().rev().any(|scope| scope.contains(name))
    }

    /// Returns every scope, outermost first.
    pub fn scopes(&self) -> &Vec<HashSet<String>> {
        &self.scopes
    }
}
58
/// The best error recorded so far, tagged with whether it was "deep".
#[cfg(all(feature = "rt", feature = "syn"))]
#[derive(Clone)]
struct ErrorState {
    /// `true` when the error's span starts somewhere other than the span at
    /// which the failing attempt began.
    is_deep: bool,
    /// The (possibly rule-annotated) error itself.
    err: syn::Error,
}
65
/// Mutable parsing state threaded through every rule of the parser.
///
/// It carries the symbol table, the span of the last recorded token, the
/// fatal flag consulted by the speculative helpers, the best error recorded
/// so far, and the stack of rules currently being parsed. Pass it by
/// `&mut` through the whole parsing chain.
#[cfg(feature = "rt")]
#[derive(Clone)]
pub struct ParseContext {
    /// Symbol table of defined names, organised in nested scopes.
    pub scopes: ScopeStack,
    /// Span of the most recently recorded token, if any.
    #[cfg(feature = "syn")]
    pub last_span: Option<Span>,
    /// When set, a failure is propagated instead of being backtracked over.
    is_fatal: bool,
    /// Best error captured by the speculative helpers.
    #[cfg(feature = "syn")]
    best_error: Option<ErrorState>,
    /// Names of the rules currently entered, innermost last.
    rule_stack: Vec<String>,
}
79
#[cfg(feature = "rt")]
impl ParseContext {
    /// Creates a fresh context.
    ///
    /// The context starts with the fatal flag cleared, no recorded error, a
    /// symbol table with one root scope, an empty rule stack and no last span.
    pub fn new() -> Self {
        ParseContext {
            scopes: ScopeStack::new(),
            rule_stack: Vec::new(),
            is_fatal: false,
            #[cfg(feature = "syn")]
            best_error: None,
            #[cfg(feature = "syn")]
            last_span: None,
        }
    }

    /// Sets or clears the fatal flag.
    pub fn set_fatal(&mut self, fatal: bool) {
        self.is_fatal = fatal;
    }

    /// Reports whether the fatal flag is currently set.
    pub fn check_fatal(&self) -> bool {
        self.is_fatal
    }

    /// Pushes `name` onto the rule stack.
    pub fn enter_rule(&mut self, name: &str) {
        self.rule_stack.push(String::from(name));
    }

    /// Pops the innermost rule name. This is a no-op when the stack is empty.
    pub fn exit_rule(&mut self) {
        let _ = self.rule_stack.pop();
    }

    /// Stores `err` as the best error when it beats the one already kept.
    ///
    /// An error counts as "deep" when its span starts at a different
    /// position than `start_span`, the start of the attempt that failed.
    /// This is a heuristic.
    ///
    /// When a rule is active, the message is prefixed with
    /// `Error in rule '<innermost rule>': `. The first recorded error is
    /// kept unless a deep error arrives while the stored one is shallow.
    #[cfg(feature = "syn")]
    pub fn record_error(&mut self, err: syn::Error, start_span: Span) {
        let is_deep = start_span.start() != err.span().start();

        let err = match self.rule_stack.last() {
            Some(rule_name) => {
                let span = err.span();
                syn::Error::new(span, format!("Error in rule '{}': {}", rule_name, err))
            }
            None => err,
        };

        let should_replace = match &self.best_error {
            None => true,
            Some(existing) => is_deep && !existing.is_deep,
        };
        if should_replace {
            self.best_error = Some(ErrorState { err, is_deep });
        }
    }

    /// Removes and returns the best recorded error, leaving none behind.
    #[cfg(feature = "syn")]
    pub fn take_best_error(&mut self) -> Option<syn::Error> {
        self.best_error.take().map(|ErrorState { err, .. }| err)
    }

    // --- Span Tracking ---

    /// Remembers `span` as the span of the last consumed token.
    #[cfg(feature = "syn")]
    pub fn record_span(&mut self, span: Span) {
        self.last_span = Some(span);
    }

    /// Returns `true` unless `next_span` starts exactly where the last
    /// recorded span ended.
    #[cfg(feature = "syn")]
    pub fn check_whitespace(&self, next_span: Span) -> bool {
        match self.last_span {
            // The spans are not adjacent when the previous end differs from
            // the next start.
            Some(last) => last.end() != next_span.start(),
            // No previous token (start of input): treat as valid.
            None => true,
        }
    }

    // --- Symbol Table Methods ---

    /// Opens a nested scope in the symbol table.
    pub fn enter_scope(&mut self) {
        self.scopes.enter_scope();
    }

    /// Closes the innermost scope. The root scope is never closed.
    pub fn exit_scope(&mut self) {
        self.scopes.exit_scope();
    }

    /// Defines `name` in the innermost scope.
    pub fn define(&mut self, name: impl Into<String>) {
        self.scopes.define(name);
    }

    /// Reports whether `name` is defined in any open scope.
    pub fn is_defined(&self, name: &str) -> bool {
        self.scopes.is_defined(name)
    }

    // --- Inspection Methods ---

    /// Returns every open scope, outermost first.
    pub fn scopes(&self) -> &Vec<HashSet<String>> {
        self.scopes.scopes()
    }

    /// Returns the names of the rules currently entered, innermost last.
    pub fn rule_stack(&self) -> &Vec<String> {
        &self.rule_stack
    }
}
188
/// Same as `ParseContext::new()`.
#[cfg(feature = "rt")]
impl Default for ParseContext {
    fn default() -> Self {
        ParseContext::new()
    }
}
195
/// Runs `parser` speculatively on a fork of `input`.
///
/// The fatal flag is cleared while `parser` runs. There are three outcomes:
///
/// * Success: `input` is advanced past what the fork consumed. Changes
///   `parser` made to `ctx` (scopes, rule stack, last span) are kept, the
///   fatal flag is restored to its entry value, and `Ok(Some(value))` is
///   returned.
/// * Failure with the fatal flag still clear: the error is passed to
///   `ParseContext::record_error` while the inner rule stack is still
///   visible. Then scopes, rule stack and last span are rolled back, the
///   fatal flag is restored, and `Ok(None)` is returned.
/// * Failure with the fatal flag set by `parser`: scopes, rule stack and
///   last span are rolled back, the fatal flag is left set, and the error
///   is propagated.
#[cfg(all(feature = "rt", feature = "syn"))]
#[inline]
pub fn attempt<T, F>(input: ParseStream, ctx: &mut ParseContext, parser: F) -> Result<Option<T>>
where
    F: FnOnce(ParseStream, &mut ParseContext) -> Result<T>,
{
    let fatal_on_entry = ctx.check_fatal();
    ctx.set_fatal(false);

    let saved_scopes = ctx.scopes.clone();
    let saved_rules = ctx.rule_stack.clone();
    let saved_last_span = ctx.last_span;

    let origin = input.span();
    let speculative = input.fork();

    let outcome = parser(&speculative, ctx);
    let became_fatal = ctx.check_fatal();

    let failure = match outcome {
        Ok(value) => {
            // Commit the fork; last_span stays as the successful parse left it.
            input.advance_to(&speculative);
            ctx.set_fatal(fatal_on_entry);
            return Ok(Some(value));
        }
        Err(failure) => failure,
    };

    if became_fatal {
        ctx.scopes = saved_scopes;
        ctx.rule_stack = saved_rules;
        ctx.last_span = saved_last_span;
        ctx.set_fatal(true);
        return Err(failure);
    }

    ctx.set_fatal(fatal_on_entry);
    // Record before rolling back so the message names the innermost failing rule.
    ctx.record_error(failure, origin);
    ctx.scopes = saved_scopes;
    ctx.rule_stack = saved_rules;
    ctx.last_span = saved_last_span;
    Ok(None)
}
251
/// Runs `parser` on a fork of `input` as pure lookahead and returns its result.
///
/// The input is never advanced. Whatever the outcome, these parts of `ctx`
/// are restored to their values on entry:
///
/// * scopes
/// * rule stack
/// * last span
/// * the fatal flag
///
/// The fatal flag is also cleared while `parser` runs, mirroring `not_check`.
/// Errors recorded into `ctx` by nested attempts are not rolled back.
#[cfg(all(feature = "rt", feature = "syn"))]
#[inline]
pub fn peek<T, F>(input: ParseStream, ctx: &mut ParseContext, parser: F) -> Result<T>
where
    F: FnOnce(ParseStream, &mut ParseContext) -> Result<T>,
{
    let fork = input.fork();

    // Snapshot state
    let scopes_snapshot = ctx.scopes.clone();
    let rule_stack_snapshot = ctx.rule_stack.clone();
    let last_span_snapshot = ctx.last_span;
    let was_fatal = ctx.check_fatal();

    // A lookahead consumes nothing. If `parser` sets the fatal flag, that
    // must not outlive the peek. Otherwise the enclosing `attempt` would treat
    // a later, unrelated failure as fatal and refuse to backtrack.
    ctx.set_fatal(false);

    let res = parser(&fork, ctx);

    // Always restore state: side effects of a peek must not persist.
    ctx.set_fatal(was_fatal);
    ctx.scopes = scopes_snapshot;
    ctx.rule_stack = rule_stack_snapshot;
    ctx.last_span = last_span_snapshot;

    res
}
276
/// Negative lookahead: runs `parser` on a fork of `input` and inverts the outcome.
///
/// If `parser` succeeds, returns `Err("unexpected match")`, spanned at the
/// current input position. If `parser` fails, returns `Ok(())`.
///
/// The input is never advanced. The fatal flag is cleared while `parser`
/// runs. Afterwards the fatal flag, scopes, rule stack and last span of
/// `ctx` are put back to their values on entry.
#[cfg(all(feature = "rt", feature = "syn"))]
#[inline]
pub fn not_check<T, F>(input: ParseStream, ctx: &mut ParseContext, parser: F) -> Result<()>
where
    F: FnOnce(ParseStream, &mut ParseContext) -> Result<T>,
{
    let lookahead = input.fork();

    let saved_scopes = ctx.scopes.clone();
    let saved_rules = ctx.rule_stack.clone();
    let saved_last_span = ctx.last_span;
    let saved_fatal = ctx.check_fatal();

    // Failing is the expected outcome here, so keep it backtrackable.
    ctx.set_fatal(false);
    let matched = parser(&lookahead, ctx).is_ok();

    ctx.set_fatal(saved_fatal);
    ctx.scopes = saved_scopes;
    ctx.rule_stack = saved_rules;
    ctx.last_span = saved_last_span;

    if matched {
        Err(syn::Error::new(input.span(), "unexpected match"))
    } else {
        Ok(())
    }
}
313
/// Speculative parse used for error-recovery blocks.
///
/// Works like `attempt`, except that a failure is never propagated. The
/// fatal flag is cleared while `parser` runs and restored to its entry
/// value afterwards, whatever happened inside.
///
/// * Success: `input` is advanced, `ctx` changes are kept, and
///   `Ok(Some(value))` is returned.
/// * Failure: the error is recorded before the rollback, so the innermost
///   rule is named. Then scopes, rule stack and last span are rolled back,
///   and `Ok(None)` is returned.
#[cfg(all(feature = "rt", feature = "syn"))]
#[inline]
pub fn attempt_recover<T, F>(
    input: ParseStream,
    ctx: &mut ParseContext,
    parser: F,
) -> Result<Option<T>>
where
    F: FnOnce(ParseStream, &mut ParseContext) -> Result<T>,
{
    let fatal_on_entry = ctx.check_fatal();
    ctx.set_fatal(false);

    let saved_scopes = ctx.scopes.clone();
    let saved_rules = ctx.rule_stack.clone();
    let saved_last_span = ctx.last_span;

    let origin = input.span();
    let speculative = input.fork();

    let outcome = parser(&speculative, ctx);
    // Discard whatever the parser did to the fatal flag.
    ctx.set_fatal(fatal_on_entry);

    match outcome {
        Err(failure) => {
            ctx.record_error(failure, origin);
            ctx.scopes = saved_scopes;
            ctx.rule_stack = saved_rules;
            ctx.last_span = saved_last_span;
            Ok(None)
        }
        Ok(value) => {
            // Commit the fork; last_span stays as the parser left it.
            input.advance_to(&speculative);
            Ok(Some(value))
        }
    }
}
360
361// --- Stateless Helpers (No Context Needed) ---
362
/// Parses a single `syn::Ident` from `input`.
#[cfg(all(feature = "rt", feature = "syn"))]
#[inline]
pub fn parse_ident(input: ParseStream) -> Result<syn::Ident> {
    input.parse::<syn::Ident>()
}
368
/// Parses an integer literal and converts its base-10 digits to `T`.
///
/// Fails in two cases: the next token is not an integer literal, or
/// `base10_parse` cannot convert the literal to `T`.
#[cfg(all(feature = "rt", feature = "syn"))]
#[inline]
pub fn parse_int<T: std::str::FromStr>(input: ParseStream) -> Result<T>
where
    T::Err: std::fmt::Display,
{
    let literal: syn::LitInt = input.parse()?;
    literal.base10_parse::<T>()
}
377
/// Discards token trees from `input` until `predicate` accepts the current
/// position or the input is exhausted.
///
/// Always returns `Ok(())`. If a token tree cannot be parsed, skipping stops
/// there.
#[cfg(all(feature = "rt", feature = "syn"))]
pub fn skip_until(input: ParseStream, predicate: impl Fn(ParseStream) -> bool) -> Result<()> {
    loop {
        // `predicate` is only consulted while tokens remain.
        if input.is_empty() || predicate(input) {
            return Ok(());
        }
        if input.parse::<proc_macro2::TokenTree>().is_err() {
            return Ok(());
        }
    }
}
387
#[cfg(all(test, feature = "rt", feature = "syn"))]
mod tests {
    use super::*;

    /// Enters each of `rules` (outermost first), records one error carrying
    /// `message`, and returns the stringified best error.
    fn best_error_message(rules: &[&str], message: &str) -> String {
        let mut ctx = ParseContext::new();
        for &rule in rules {
            ctx.enter_rule(rule);
        }
        ctx.record_error(
            syn::Error::new(Span::call_site(), message),
            Span::call_site(),
        );
        ctx.take_best_error().unwrap().to_string()
    }

    #[test]
    fn test_rule_name_in_error() {
        assert_eq!(
            best_error_message(&["test_rule"], "expected something"),
            "Error in rule 'test_rule': expected something"
        );
    }

    #[test]
    fn test_nested_rule_name_in_error() {
        // Only the innermost rule is named in the message.
        assert_eq!(
            best_error_message(&["outer", "inner"], "fail"),
            "Error in rule 'inner': fail"
        );
    }

    #[test]
    fn test_attempt_captures_rule_context() {
        use syn::parse::Parser;

        let mut ctx = ParseContext::new();

        // Enter a rule, then run an `attempt` whose closure always fails.
        // The failure is non-fatal, so `attempt` records it, prefixed with
        // the active rule name, and yields `Ok(None)`. The outer parser
        // therefore still succeeds.
        let parser = |input: ParseStream| {
            ctx.enter_rule("outer");

            let _: Option<()> = attempt(input, &mut ctx, |_input, _ctx| {
                Err(syn::Error::new(Span::call_site(), "parse failed"))
            })?;

            ctx.exit_rule();
            Ok(())
        };

        // Empty input: only the error recorded in `ctx` matters, not the
        // outer result.
        let _ = parser.parse_str("");

        let recorded = ctx.take_best_error().expect("Error should be recorded");
        assert_eq!(recorded.to_string(), "Error in rule 'outer': parse failed");
    }
}