Skip to main content

relon_parser/
ast.rs

1//! Typed-AST wrappers over the rowan CST.
2//!
3//! P3 of the rowan rewrite. Each wrapper is a transparent tuple struct
4//! around a `SyntaxNode` (or `SyntaxToken`) — there is no extra
5//! allocation, and a wrapper can be obtained from a CST node in
6//! O(1) via `cast(node)` (which returns `None` when the kind doesn't
7//! match).
8//!
9//! The wrapper API mirrors the existing token-tree types in
10//! [`crate::token`] (`Node`, `Expr`, `TokenKey`, `Directive`,
11//! `Decorator`, `TypeNode`, etc.) as closely as possible. P4 will
12//! migrate downstream crates (analyzer, evaluator, fmt, wasm, lsp)
13//! by swapping their winnow-based parsing for these wrappers; the
14//! parallel API shape keeps that migration mostly mechanical.
15//!
16//! ## Variant kinds
17//!
//! `Expr` is the central typed enum — one variant per expression-shaped
//! CST kind (broadly mirroring [`crate::Expr`], though all literal atoms
//! share a single `Literal` variant) plus a new [`Expr::Error`] for spans
//! the CST couldn't fit into any production. Each variant carries a
//! structured wrapper that exposes the relevant children:
22//!
23//! * `Expr::Dict(Dict)` — has `.fields()` iterator.
24//! * `Expr::List(List)` / `Expr::Comprehension(Comprehension)`.
25//! * `Expr::Binary(BinaryExpr)` — `.op_kind()` + `.lhs()` + `.rhs()`.
26//! * `Expr::Call(CallExpr)` — `.callee()` + `.args()`.
27//! * `Expr::Closure(Closure)` — `.params()` + `.body()`.
28//! * ...etc, see the impls below.
29//!
//! When the underlying `SyntaxKind` is `ERROR`, `Expr::cast` returns
//! `Expr::Error(ErrorNode)`; for any other kind that is not
//! expression-shaped it returns `None`. Downstream callers must add an
//! `Expr::Error` arm (typically a no-op) to their match statements.
33
34use crate::syntax::{SyntaxKind, SyntaxNode};
35
36// =====================================================================
37// Macro: define a typed wrapper around a SyntaxNode of one specific
38// kind. Generates the standard `cast` / `syntax` / `text` boilerplate.
39// =====================================================================
40
/// Declares a typed wrapper around a [`SyntaxNode`] of exactly one
/// `SyntaxKind`, plus the shared `cast` / `syntax` / `text` accessors.
/// Any attributes (including `///` docs) written before the wrapper name
/// are forwarded onto the generated struct.
macro_rules! ast_node {
    ($(#[$meta:meta])* $name:ident, $kind:ident) => {
        $(#[$meta])*
        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
        pub struct $name(SyntaxNode);

        impl $name {
            /// Wraps `node` when its kind is the one this wrapper was
            /// declared for, and returns `None` for any other kind. Only
            /// the kind is inspected, so this is O(1).
            pub fn cast(node: SyntaxNode) -> Option<Self> {
                if node.kind() != SyntaxKind::$kind {
                    return None;
                }
                Some(Self(node))
            }

            /// The wrapped [`SyntaxNode`], for callers that need raw CST
            /// access.
            pub fn syntax(&self) -> &SyntaxNode {
                let Self(node) = self;
                node
            }

            /// Source text covered by this node, verbatim (trivia
            /// included), copied into a fresh `String`.
            pub fn text(&self) -> String {
                let spanned = self.0.text();
                spanned.to_string()
            }
        }
    };
}
72
73ast_node!(
74    /// `DOCUMENT` root — every parse produces exactly one. Carries
75    /// the leading directives + the root expression body.
76    Document, DOCUMENT
77);
78
79ast_node!(
80    /// `#name <body>` form. Body shape depends on the directive's
81    /// name; the typed-AST layer reads it from [`Directive::name`].
82    Directive, DIRECTIVE
83);
84
85ast_node!(
86    /// `@name(args?)` decorator.
87    Decorator, DECORATOR
88);
89
90ast_node!(Dict, DICT);
91ast_node!(DictField, DICT_FIELD);
92ast_node!(List, LIST);
93ast_node!(
94    /// `(e1, e2, ...)` tuple value literal. `()` is the unit tuple
95    /// (no children); `(e,)` the trailing-comma 1-tuple. A bare
96    /// grouping `(e)` is NOT a tuple and never produces this node.
97    Tuple,
98    TUPLE
99);
100ast_node!(Comprehension, COMPREHENSION);
101ast_node!(Closure, CLOSURE);
102ast_node!(ClosureParam, CLOSURE_PARAM);
103ast_node!(CallExpr, CALL_EXPR);
104ast_node!(CallArg, CALL_ARG);
105ast_node!(BinaryExpr, BINARY_EXPR);
106ast_node!(UnaryExpr, UNARY_EXPR);
107ast_node!(TernaryExpr, TERNARY_EXPR);
108ast_node!(ReferenceExpr, REFERENCE_EXPR);
109ast_node!(VariableExpr, VARIABLE_EXPR);
110ast_node!(WhereExpr, WHERE_EXPR);
111ast_node!(MatchExpr, MATCH_EXPR);
112ast_node!(MatchArm, MATCH_ARM);
113ast_node!(VariantCtor, VARIANT_CTOR);
114ast_node!(FString, F_STRING);
115ast_node!(FStringInterpolation, F_STRING_INTERPOLATION);
116ast_node!(SpreadExpr, SPREAD_EXPR);
117ast_node!(TypeNode, TYPE_NODE);
118ast_node!(
119    /// `(T1, T2, ...)` tuple type. Sits inside a TypeNode-shaped
120    /// position (typed dict field, generic argument list, closure
121    /// parameter, schema-method param).
122    TupleType,
123    TUPLE_TYPE
124);
125ast_node!(
126    /// `with { ... }` body of a `#schema` / `#extend` directive.
127    /// Children: pragma directives + zero or more [`SchemaMethod`].
128    SchemaWith,
129    SCHEMA_WITH
130);
131ast_node!(
132    /// One method declaration inside a [`SchemaWith`] block.
133    /// Children: leading pragma directives, the method name token,
134    /// optional generic params, a closure-param list, return type,
135    /// and an optional body expression.
136    SchemaMethod,
137    SCHEMA_METHOD
138);
139ast_node!(Wildcard, WILDCARD);
140ast_node!(Literal, LITERAL);
141ast_node!(ErrorNode, ERROR);
142
143// =====================================================================
144// `Expr` — the top-level typed enum. Mirrors `crate::Expr` plus a new
145// `Error` variant for partial parses.
146// =====================================================================
147
148/// Typed view over any expression-shaped CST node. Returned by
149/// [`Expr::cast`].
150///
151/// Note the variant naming follows the CST kinds, not the legacy
152/// [`crate::Expr`] enum — `Literal` covers `true` / `false` / numeric / string atoms uniformly,
153/// plus the removed `null` spelling for diagnostics. The legacy enum split
154/// them into `Bool` / `Int` / `Float` / `String` plus internal `Missing`. Downstream
155/// callers that need the specific literal type read it from
156/// [`Literal::kind`] / [`Literal::value_text`].
157#[derive(Debug, Clone, PartialEq, Eq, Hash)]
158pub enum Expr {
159    Literal(Literal),
160    Variable(VariableExpr),
161    Reference(ReferenceExpr),
162    Dict(Dict),
163    List(List),
164    Tuple(Tuple),
165    Spread(SpreadExpr),
166    Comprehension(Comprehension),
167    Binary(BinaryExpr),
168    Unary(UnaryExpr),
169    Ternary(TernaryExpr),
170    Call(CallExpr),
171    FString(FString),
172    Type(TypeNode),
173    Wildcard(Wildcard),
174    Where(WhereExpr),
175    Match(MatchExpr),
176    Closure(Closure),
177    VariantCtor(VariantCtor),
178    /// New variant introduced by the CST rewrite — spans bytes the
179    /// parser couldn't fit into any production. Downstream callers
180    /// must handle it (typically by skipping or surfacing a
181    /// diagnostic).
182    Error(ErrorNode),
183}
184
185impl Expr {
186    /// Wrap `node` if it names an expression-shaped CST kind.
187    /// Returns `None` for non-expression nodes (DICT_FIELD,
188    /// CLOSURE_PARAM, etc.) — those have their own typed wrappers.
189    pub fn cast(node: SyntaxNode) -> Option<Self> {
190        Some(match node.kind() {
191            SyntaxKind::LITERAL => Self::Literal(Literal(node)),
192            SyntaxKind::VARIABLE_EXPR => Self::Variable(VariableExpr(node)),
193            SyntaxKind::REFERENCE_EXPR => Self::Reference(ReferenceExpr(node)),
194            SyntaxKind::DICT => Self::Dict(Dict(node)),
195            SyntaxKind::LIST => Self::List(List(node)),
196            SyntaxKind::TUPLE => Self::Tuple(Tuple(node)),
197            SyntaxKind::SPREAD_EXPR => Self::Spread(SpreadExpr(node)),
198            SyntaxKind::COMPREHENSION => Self::Comprehension(Comprehension(node)),
199            SyntaxKind::BINARY_EXPR => Self::Binary(BinaryExpr(node)),
200            SyntaxKind::UNARY_EXPR => Self::Unary(UnaryExpr(node)),
201            SyntaxKind::TERNARY_EXPR => Self::Ternary(TernaryExpr(node)),
202            SyntaxKind::CALL_EXPR => Self::Call(CallExpr(node)),
203            SyntaxKind::F_STRING => Self::FString(FString(node)),
204            SyntaxKind::TYPE_NODE => Self::Type(TypeNode(node)),
205            SyntaxKind::WILDCARD => Self::Wildcard(Wildcard(node)),
206            SyntaxKind::WHERE_EXPR => Self::Where(WhereExpr(node)),
207            SyntaxKind::MATCH_EXPR => Self::Match(MatchExpr(node)),
208            SyntaxKind::CLOSURE => Self::Closure(Closure(node)),
209            SyntaxKind::VARIANT_CTOR => Self::VariantCtor(VariantCtor(node)),
210            SyntaxKind::ERROR => Self::Error(ErrorNode(node)),
211            _ => return None,
212        })
213    }
214
215    /// Borrow the underlying [`SyntaxNode`] regardless of variant.
216    pub fn syntax(&self) -> &SyntaxNode {
217        match self {
218            Self::Literal(n) => n.syntax(),
219            Self::Variable(n) => n.syntax(),
220            Self::Reference(n) => n.syntax(),
221            Self::Dict(n) => n.syntax(),
222            Self::List(n) => n.syntax(),
223            Self::Tuple(n) => n.syntax(),
224            Self::Spread(n) => n.syntax(),
225            Self::Comprehension(n) => n.syntax(),
226            Self::Binary(n) => n.syntax(),
227            Self::Unary(n) => n.syntax(),
228            Self::Ternary(n) => n.syntax(),
229            Self::Call(n) => n.syntax(),
230            Self::FString(n) => n.syntax(),
231            Self::Type(n) => n.syntax(),
232            Self::Wildcard(n) => n.syntax(),
233            Self::Where(n) => n.syntax(),
234            Self::Match(n) => n.syntax(),
235            Self::Closure(n) => n.syntax(),
236            Self::VariantCtor(n) => n.syntax(),
237            Self::Error(n) => n.syntax(),
238        }
239    }
240
241    /// Verbatim source text. Convenience over `self.syntax().text()`.
242    pub fn text(&self) -> String {
243        self.syntax().text().to_string()
244    }
245}
246
247// =====================================================================
248// Per-node accessors. Each wrapper exposes the structural data the
249// downstream typed-AST callers will need, mirroring the existing
250// `Node` / `Expr` API in `token.rs`.
251// =====================================================================
252
253impl Document {
254    /// All directives stacked above the root value, in source order.
255    pub fn directives(&self) -> impl Iterator<Item = Directive> + '_ {
256        self.0.children().filter_map(Directive::cast)
257    }
258
259    /// All decorators stacked above the root value, in source order.
260    pub fn decorators(&self) -> impl Iterator<Item = Decorator> + '_ {
261        self.0.children().filter_map(Decorator::cast)
262    }
263
264    /// The root expression, if the file has one. Files containing
265    /// only directives (e.g. a `#schema` library) have `None`.
266    pub fn root_expr(&self) -> Option<Expr> {
267        self.0.children().find_map(Expr::cast)
268    }
269}
270
271impl Directive {
272    /// Directive name (everything after the `#`). `None` when the
273    /// parser emitted an ERROR before the name was captured.
274    pub fn name(&self) -> Option<String> {
275        // The first IDENT token under the DIRECTIVE node is the name
276        // (the `#` itself is the leading leaf).
277        self.0
278            .children_with_tokens()
279            .filter_map(|el| el.into_token())
280            .find(|t| t.kind() == SyntaxKind::IDENT)
281            .map(|t| t.text().to_string())
282    }
283
284    /// Direct-child expression(s) of the directive body. For
285    /// `#schema X { ... }` this yields the body dict; for
286    /// `#default 0` it yields the value expression. The number of
287    /// items is shape-dependent and the typed-AST layer above this
288    /// crate decides interpretation.
289    pub fn body_exprs(&self) -> impl Iterator<Item = Expr> + '_ {
290        self.0.children().filter_map(Expr::cast)
291    }
292}
293
294impl Decorator {
295    pub fn name(&self) -> Option<String> {
296        self.0
297            .children_with_tokens()
298            .filter_map(|el| el.into_token())
299            .find(|t| t.kind() == SyntaxKind::IDENT)
300            .map(|t| t.text().to_string())
301    }
302
303    pub fn args(&self) -> impl Iterator<Item = Expr> + '_ {
304        // CALL_ARG is the child node holding the parens; the args
305        // inside it are the actual expressions.
306        self.0
307            .children()
308            .find(|c| c.kind() == SyntaxKind::CALL_ARG)
309            .into_iter()
310            .flat_map(|n| n.children().filter_map(Expr::cast).collect::<Vec<_>>())
311    }
312}
313
314impl Dict {
315    pub fn fields(&self) -> impl Iterator<Item = DictField> + '_ {
316        self.0.children().filter_map(DictField::cast)
317    }
318}
319
320impl DictField {
321    /// Key text — the bare identifier or string-literal key.
322    /// Returns `None` for spread / dynamic-keyed fields (callers
323    /// should inspect the children directly for those shapes).
324    pub fn key_text(&self) -> Option<String> {
325        self.0
326            .children_with_tokens()
327            .filter_map(|el| el.into_token())
328            .find(|t| t.kind() == SyntaxKind::IDENT || t.kind() == SyntaxKind::STRING)
329            .map(|t| t.text().to_string())
330    }
331
332    /// The value expression. For method-shorthand closure fields
333    /// (`key(params): body`) this is the closure; otherwise it's
334    /// whatever follows the `:`.
335    pub fn value(&self) -> Option<Expr> {
336        self.0.children().filter_map(Expr::cast).next()
337    }
338}
339
340impl List {
341    pub fn items(&self) -> impl Iterator<Item = Expr> + '_ {
342        self.0.children().filter_map(Expr::cast)
343    }
344}
345
346impl Tuple {
347    /// Element expressions in source order. Empty for the unit tuple `()`.
348    pub fn items(&self) -> impl Iterator<Item = Expr> + '_ {
349        self.0.children().filter_map(Expr::cast)
350    }
351}
352
353impl Comprehension {
354    /// `[ element for id in iterable (if cond)? ]`. Returns the
355    /// inner expressions in their structural roles. Falls back on
356    /// CST-order when a malformed comprehension drops one of them.
357    pub fn parts(&self) -> Vec<Expr> {
358        self.0.children().filter_map(Expr::cast).collect()
359    }
360
361    /// The bound identifier between `for` and `in`. `None` on a
362    /// malformed comprehension.
363    pub fn binding(&self) -> Option<String> {
364        let mut after_for = false;
365        for el in self.0.children_with_tokens() {
366            if let Some(t) = el.as_token() {
367                if t.kind() == SyntaxKind::IDENT {
368                    let s = t.text();
369                    if after_for {
370                        return Some(s.to_string());
371                    }
372                    if s == "for" {
373                        after_for = true;
374                    }
375                }
376            }
377        }
378        None
379    }
380}
381
382impl Closure {
383    pub fn params(&self) -> impl Iterator<Item = ClosureParam> + '_ {
384        self.0.children().filter_map(ClosureParam::cast)
385    }
386
387    /// Optional return type — `-> Type`. The TYPE_NODE child that
388    /// follows the `->` arrow.
389    pub fn return_type(&self) -> Option<TypeNode> {
390        let mut saw_arrow = false;
391        for el in self.0.children_with_tokens() {
392            if let Some(t) = el.as_token() {
393                if t.kind() == SyntaxKind::THIN_ARROW {
394                    saw_arrow = true;
395                }
396            } else if let Some(n) = el.as_node() {
397                if saw_arrow && n.kind() == SyntaxKind::TYPE_NODE {
398                    return TypeNode::cast(n.clone());
399                }
400            }
401        }
402        None
403    }
404
405    /// The body expression — everything after `=>` (or after `:`
406    /// for the dict-field method shorthand).
407    pub fn body(&self) -> Option<Expr> {
408        // The body is the LAST expression child — both the typed
409        // params (which contain their own TYPE_NODEs) and the
410        // return type sit before it in source order. Filter out
411        // the return TYPE_NODE if it exists.
412        let mut last: Option<Expr> = None;
413        for child in self.0.children() {
414            if child.kind() == SyntaxKind::CLOSURE_PARAM || child.kind() == SyntaxKind::TYPE_NODE {
415                continue;
416            }
417            if let Some(e) = Expr::cast(child) {
418                last = Some(e);
419            }
420        }
421        last
422    }
423}
424
425impl ClosureParam {
426    pub fn name(&self) -> Option<String> {
427        // The non-type IDENT is the parameter name. Skip any
428        // TYPE_NODE child and pick the trailing IDENT token.
429        self.0
430            .children_with_tokens()
431            .filter_map(|el| el.into_token())
432            .filter(|t| t.kind() == SyntaxKind::IDENT)
433            .last()
434            .map(|t| t.text().to_string())
435    }
436
437    pub fn type_hint(&self) -> Option<TypeNode> {
438        self.0.children().find_map(TypeNode::cast)
439    }
440}
441
442impl CallExpr {
443    /// The callee expression (the thing being called). It's the
444    /// first expression child — typically a VARIABLE_EXPR but in
445    /// principle any postfix-able expression.
446    pub fn callee(&self) -> Option<Expr> {
447        self.0.children().find_map(Expr::cast)
448    }
449
450    /// Arguments inside the parens.
451    pub fn args(&self) -> impl Iterator<Item = Expr> + '_ {
452        self.0
453            .children()
454            .find(|c| c.kind() == SyntaxKind::CALL_ARG)
455            .into_iter()
456            .flat_map(|n| n.children().filter_map(Expr::cast).collect::<Vec<_>>())
457    }
458}
459
460impl BinaryExpr {
461    /// Return the operator token's `SyntaxKind` (e.g.
462    /// `SyntaxKind::PLUS`, `SyntaxKind::EQ_EQ`). `None` only on a
463    /// malformed parse where the operator token is missing.
464    pub fn op_kind(&self) -> Option<SyntaxKind> {
465        self.0
466            .children_with_tokens()
467            .filter_map(|el| el.into_token())
468            .map(|t| t.kind())
469            .find(|k| {
470                matches!(
471                    k,
472                    SyntaxKind::PLUS
473                        | SyntaxKind::MINUS
474                        | SyntaxKind::STAR
475                        | SyntaxKind::SLASH
476                        | SyntaxKind::PERCENT
477                        | SyntaxKind::PLUS_PLUS
478                        | SyntaxKind::EQ_EQ
479                        | SyntaxKind::BANG_EQ
480                        | SyntaxKind::LT
481                        | SyntaxKind::GT
482                        | SyntaxKind::LT_EQ
483                        | SyntaxKind::GT_EQ
484                        | SyntaxKind::AMP_AMP
485                        | SyntaxKind::PIPE_PIPE
486                        | SyntaxKind::PIPE
487                )
488            })
489    }
490
491    pub fn lhs(&self) -> Option<Expr> {
492        self.0.children().find_map(Expr::cast)
493    }
494
495    pub fn rhs(&self) -> Option<Expr> {
496        self.0.children().filter_map(Expr::cast).nth(1)
497    }
498}
499
500impl UnaryExpr {
501    /// Operator token kind (`MINUS` / `BANG` / `PLUS`).
502    pub fn op_kind(&self) -> Option<SyntaxKind> {
503        self.0
504            .children_with_tokens()
505            .filter_map(|el| el.into_token())
506            .map(|t| t.kind())
507            .find(|k| matches!(k, SyntaxKind::MINUS | SyntaxKind::BANG | SyntaxKind::PLUS))
508    }
509
510    pub fn operand(&self) -> Option<Expr> {
511        self.0.children().find_map(Expr::cast)
512    }
513}
514
515impl TernaryExpr {
516    pub fn cond(&self) -> Option<Expr> {
517        self.0.children().find_map(Expr::cast)
518    }
519
520    pub fn then(&self) -> Option<Expr> {
521        self.0.children().filter_map(Expr::cast).nth(1)
522    }
523
524    pub fn els(&self) -> Option<Expr> {
525        self.0.children().filter_map(Expr::cast).nth(2)
526    }
527}
528
529impl ReferenceExpr {
530    /// Reference base identifier (`root`, `sibling`, `uncle`,
531    /// `this`, `prev`, `next`, `index`). The CST keeps the bare
532    /// IDENT token directly under the node.
533    pub fn base_name(&self) -> Option<String> {
534        self.0
535            .children_with_tokens()
536            .filter_map(|el| el.into_token())
537            .find(|t| t.kind() == SyntaxKind::IDENT)
538            .map(|t| t.text().to_string())
539    }
540
541    /// Whole `&base.x.y` text. Cheap fallback when callers don't
542    /// need to inspect each path segment individually.
543    pub fn path_text(&self) -> String {
544        self.text()
545    }
546}
547
548impl VariableExpr {
549    /// Every IDENT-shaped path segment in source order.
550    pub fn segments(&self) -> Vec<String> {
551        self.0
552            .children_with_tokens()
553            .filter_map(|el| el.into_token())
554            .filter(|t| t.kind() == SyntaxKind::IDENT)
555            .map(|t| t.text().to_string())
556            .collect()
557    }
558}
559
560impl Literal {
561    /// Kind of the underlying literal token. Useful for the
562    /// `true`/`false`/NUMBER/STRING dispatch downstream
563    /// callers need to type-check.
564    pub fn kind(&self) -> Option<SyntaxKind> {
565        self.0
566            .children_with_tokens()
567            .filter_map(|el| el.into_token())
568            .map(|t| t.kind())
569            .find(|k| {
570                matches!(
571                    k,
572                    SyntaxKind::NUMBER | SyntaxKind::STRING | SyntaxKind::IDENT
573                )
574            })
575    }
576
577    /// Verbatim text of the literal token (e.g. `"42"`, `r#""hi""#`,
578    /// `"true"`).
579    pub fn value_text(&self) -> String {
580        self.text()
581    }
582}
583
584impl WhereExpr {
585    /// The leading expression (everything before `where`).
586    pub fn expr(&self) -> Option<Expr> {
587        self.0.children().find_map(Expr::cast)
588    }
589
590    /// The binding dict that follows `where`.
591    pub fn bindings(&self) -> Option<Dict> {
592        self.0.children().filter_map(Dict::cast).next()
593    }
594}
595
596impl MatchExpr {
597    /// The scrutinee (everything before `match`).
598    pub fn scrutinee(&self) -> Option<Expr> {
599        self.0.children().find_map(Expr::cast)
600    }
601
602    pub fn arms(&self) -> impl Iterator<Item = MatchArm> + '_ {
603        self.0.children().filter_map(MatchArm::cast)
604    }
605}
606
607impl MatchArm {
608    /// Pattern — typically a TYPE_NODE; `*` wildcards parse as
609    /// a [`Wildcard`] child.
610    pub fn pattern(&self) -> Option<Expr> {
611        self.0.children().find_map(Expr::cast)
612    }
613
614    /// Arm body (everything after `:`).
615    pub fn body(&self) -> Option<Expr> {
616        self.0.children().filter_map(Expr::cast).nth(1)
617    }
618}
619
620impl SpreadExpr {
621    /// The inner expression being spread.
622    pub fn inner(&self) -> Option<Expr> {
623        self.0.children().find_map(Expr::cast)
624    }
625}
626
627impl VariantCtor {
628    /// Body dict literal `Enum.Variant { ... }`.
629    pub fn body(&self) -> Option<Dict> {
630        self.0.children().find_map(Dict::cast)
631    }
632}
633
634impl FString {
635    /// Iterator over the f-string's literal text chunks and
636    /// interpolation sub-nodes, in source order.
637    pub fn parts(&self) -> Vec<FStringPart> {
638        let mut out = Vec::new();
639        for el in self.0.children_with_tokens() {
640            if let Some(t) = el.as_token() {
641                if t.kind() == SyntaxKind::F_STRING_LITERAL {
642                    out.push(FStringPart::Literal(t.text().to_string()));
643                }
644            } else if let Some(n) = el.as_node() {
645                if let Some(interp) = FStringInterpolation::cast(n.clone()) {
646                    out.push(FStringPart::Interpolation(interp));
647                }
648            }
649        }
650        out
651    }
652}
653
654impl FStringInterpolation {
655    /// The inner expression — what gets evaluated and formatted in.
656    pub fn expr(&self) -> Option<Expr> {
657        self.0.children().find_map(Expr::cast)
658    }
659}
660
661/// View of one piece of an [`FString`]. Mirrors `crate::FStringPart`
662/// at the rowan side.
663#[derive(Debug, Clone, PartialEq, Eq, Hash)]
664pub enum FStringPart {
665    Literal(String),
666    Interpolation(FStringInterpolation),
667}
668
669impl TypeNode {
670    /// Path segments — `Foo` / `Foo.Bar` / `"foo".Bar`. Returns the
671    /// raw text of each IDENT / STRING token preceding the first
672    /// generic / `?`.
673    pub fn path_text(&self) -> Vec<String> {
674        let mut out = Vec::new();
675        for el in self.0.children_with_tokens() {
676            if let Some(t) = el.as_token() {
677                match t.kind() {
678                    SyntaxKind::LT => break,
679                    SyntaxKind::QUESTION => break,
680                    SyntaxKind::DOT => continue,
681                    SyntaxKind::IDENT | SyntaxKind::STRING => out.push(t.text().to_string()),
682                    _ => {}
683                }
684            } else {
685                break;
686            }
687        }
688        out
689    }
690
691    /// Direct-child TYPE_NODEs nested inside this one's generic
692    /// argument list.
693    pub fn generics(&self) -> impl Iterator<Item = TypeNode> + '_ {
694        self.0.children().filter_map(TypeNode::cast)
695    }
696
697    /// `Foo?` — true when the trailing `?` is present.
698    pub fn is_optional(&self) -> bool {
699        self.0
700            .children_with_tokens()
701            .filter_map(|el| el.into_token())
702            .any(|t| t.kind() == SyntaxKind::QUESTION)
703    }
704}
705
706// =====================================================================
707// Convenience entry points.
708// =====================================================================
709
710/// Cast the root of a [`crate::cst::Parse`] result to a typed
711/// [`Document`]. The root kind is always `DOCUMENT` so the call
712/// never returns `None`; the `Option` is for API uniformity with
713/// the other `cast` entries.
714pub fn document_of(syntax: SyntaxNode) -> Option<Document> {
715    Document::cast(syntax)
716}
717
718/// Re-export of [`crate::syntax::SyntaxToken`] for callers who need it but don't
719/// otherwise depend on the `syntax` module.
720pub use crate::syntax::SyntaxToken as _Token;
721
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cst::parse_cst;

    /// Parses `src` and returns its typed `DOCUMENT` root.
    fn doc_of(src: &str) -> Document {
        let parsed = parse_cst(src);
        Document::cast(parsed.syntax()).expect("DOCUMENT kind")
    }

    /// Parses `src`, requiring the root expression to be a dict.
    fn root_dict(src: &str) -> Dict {
        match doc_of(src).root_expr() {
            Some(Expr::Dict(dict)) => dict,
            other => panic!("root is not a dict: {other:?}"),
        }
    }

    /// Value of the first field of the root dict in `src`.
    fn first_value(src: &str) -> Expr {
        root_dict(src)
            .fields()
            .next()
            .and_then(|field| field.value())
            .expect("first field has a value")
    }

    #[test]
    fn document_round_trip() {
        assert!(doc_of("{ a: 1, b: 2 }").root_expr().is_some());
    }

    #[test]
    fn dict_fields() {
        let keys: Vec<String> = root_dict("{ alice: 1, bob: 2 }")
            .fields()
            .filter_map(|field| field.key_text())
            .collect();
        assert_eq!(keys, vec!["alice".to_string(), "bob".to_string()]);
    }

    #[test]
    fn binary_op_kind() {
        let bin = match first_value("{ x: 1 + 2 }") {
            Expr::Binary(b) => b,
            other => panic!("not binary: {other:?}"),
        };
        assert_eq!(bin.op_kind(), Some(SyntaxKind::PLUS));
        assert!(bin.lhs().is_some());
        assert!(bin.rhs().is_some());
    }

    #[test]
    fn closure_typed_params() {
        let closure = match first_value("{ add(Int a, Int b): a + b }") {
            Expr::Closure(c) => c,
            other => panic!("not closure: {other:?}"),
        };
        let params: Vec<ClosureParam> = closure.params().collect();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].name().as_deref(), Some("a"));
        assert!(params[0].type_hint().is_some());
    }

    #[test]
    fn f_string_parts() {
        let fstring = match first_value(r#"{ msg: f"hi ${name}!" }"#) {
            Expr::FString(f) => f,
            other => panic!("not an f-string: {other:?}"),
        };
        let parts = fstring.parts();
        let has_lit = parts.iter().any(|p| matches!(p, FStringPart::Literal(_)));
        let has_interp = parts
            .iter()
            .any(|p| matches!(p, FStringPart::Interpolation(_)));
        assert!(has_lit && has_interp);
    }

    #[test]
    fn directive_name() {
        let dirs: Vec<Directive> = doc_of("#schema X { Int a: * }\n{ x: 1 }")
            .directives()
            .collect();
        assert_eq!(dirs.len(), 1);
        assert_eq!(dirs[0].name().as_deref(), Some("schema"));
    }

    #[test]
    fn match_arms() {
        let closure = match first_value("{ f(x): x match { Int: 1, _ : 0 } }") {
            Expr::Closure(c) => c,
            other => panic!("not closure: {other:?}"),
        };
        let matched = match closure.body() {
            Some(Expr::Match(m)) => m,
            other => panic!("closure body is not a match: {other:?}"),
        };
        assert_eq!(matched.arms().count(), 2);
    }

    #[test]
    fn error_variant_for_partial_parse() {
        // Malformed bytes force the parser to emit at least one ERROR node.
        let doc = doc_of("{ broken @ # }");
        // Every ERROR node must surface as `Expr::Error` through `Expr::cast`.
        let any_error = doc
            .syntax()
            .descendants()
            .filter_map(Expr::cast)
            .any(|e| matches!(e, Expr::Error(_)));
        assert!(any_error, "expected at least one Expr::Error variant");
    }
}