1use crate::syntax::{SyntaxKind, SyntaxNode};
35
/// Declares a typed newtype wrapper around [`SyntaxNode`] for a single node kind.
///
/// `ast_node!(Name, KIND)` expands to `pub struct Name(SyntaxNode)` together with
/// `cast`, `syntax` and `text` methods. Attributes written before the name
/// (doc comments included) are forwarded onto the generated struct.
macro_rules! ast_node {
    ($(#[$attr:meta])* $name:ident, $kind:ident) => {
        $(#[$attr])*
        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
        pub struct $name(SyntaxNode);

        impl $name {
            /// Wraps `node` when its kind matches this type's `SyntaxKind`;
            /// returns `None` for any other kind.
            pub fn cast(node: SyntaxNode) -> Option<Self> {
                match node.kind() {
                    SyntaxKind::$kind => Some(Self(node)),
                    _ => None,
                }
            }

            /// Borrows the wrapped syntax node.
            pub fn syntax(&self) -> &SyntaxNode {
                let Self(node) = self;
                node
            }

            /// Returns the text of the wrapped node as an owned `String`.
            pub fn text(&self) -> String {
                self.syntax().text().to_string()
            }
        }
    };
}
72
73ast_node!(
74 Document, DOCUMENT
77);
78
79ast_node!(
80 Directive, DIRECTIVE
83);
84
85ast_node!(
86 Decorator, DECORATOR
88);
89
90ast_node!(Dict, DICT);
91ast_node!(DictField, DICT_FIELD);
92ast_node!(List, LIST);
93ast_node!(
94 Tuple,
98 TUPLE
99);
100ast_node!(Comprehension, COMPREHENSION);
101ast_node!(Closure, CLOSURE);
102ast_node!(ClosureParam, CLOSURE_PARAM);
103ast_node!(CallExpr, CALL_EXPR);
104ast_node!(CallArg, CALL_ARG);
105ast_node!(BinaryExpr, BINARY_EXPR);
106ast_node!(UnaryExpr, UNARY_EXPR);
107ast_node!(TernaryExpr, TERNARY_EXPR);
108ast_node!(ReferenceExpr, REFERENCE_EXPR);
109ast_node!(VariableExpr, VARIABLE_EXPR);
110ast_node!(WhereExpr, WHERE_EXPR);
111ast_node!(MatchExpr, MATCH_EXPR);
112ast_node!(MatchArm, MATCH_ARM);
113ast_node!(VariantCtor, VARIANT_CTOR);
114ast_node!(FString, F_STRING);
115ast_node!(FStringInterpolation, F_STRING_INTERPOLATION);
116ast_node!(SpreadExpr, SPREAD_EXPR);
117ast_node!(TypeNode, TYPE_NODE);
118ast_node!(
119 TupleType,
123 TUPLE_TYPE
124);
125ast_node!(
126 SchemaWith,
129 SCHEMA_WITH
130);
131ast_node!(
132 SchemaMethod,
137 SCHEMA_METHOD
138);
139ast_node!(Wildcard, WILDCARD);
140ast_node!(Literal, LITERAL);
141ast_node!(ErrorNode, ERROR);
142
/// A typed view over any syntax node that is treated as an expression.
///
/// Each variant wraps the typed node for one `SyntaxKind`; `Expr::cast`
/// picks the variant from the node's kind.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Expr {
    /// A `LITERAL` node.
    Literal(Literal),
    /// A `VARIABLE_EXPR` node.
    Variable(VariableExpr),
    /// A `REFERENCE_EXPR` node.
    Reference(ReferenceExpr),
    /// A `DICT` node.
    Dict(Dict),
    /// A `LIST` node.
    List(List),
    /// A `TUPLE` node.
    Tuple(Tuple),
    /// A `SPREAD_EXPR` node.
    Spread(SpreadExpr),
    /// A `COMPREHENSION` node.
    Comprehension(Comprehension),
    /// A `BINARY_EXPR` node.
    Binary(BinaryExpr),
    /// A `UNARY_EXPR` node.
    Unary(UnaryExpr),
    /// A `TERNARY_EXPR` node.
    Ternary(TernaryExpr),
    /// A `CALL_EXPR` node.
    Call(CallExpr),
    /// An `F_STRING` node.
    FString(FString),
    /// A `TYPE_NODE` node appearing in expression position.
    Type(TypeNode),
    /// A `WILDCARD` node.
    Wildcard(Wildcard),
    /// A `WHERE_EXPR` node.
    Where(WhereExpr),
    /// A `MATCH_EXPR` node.
    Match(MatchExpr),
    /// A `CLOSURE` node.
    Closure(Closure),
    /// A `VARIANT_CTOR` node.
    VariantCtor(VariantCtor),
    /// An `ERROR` node.
    Error(ErrorNode),
}
184
185impl Expr {
186 pub fn cast(node: SyntaxNode) -> Option<Self> {
190 Some(match node.kind() {
191 SyntaxKind::LITERAL => Self::Literal(Literal(node)),
192 SyntaxKind::VARIABLE_EXPR => Self::Variable(VariableExpr(node)),
193 SyntaxKind::REFERENCE_EXPR => Self::Reference(ReferenceExpr(node)),
194 SyntaxKind::DICT => Self::Dict(Dict(node)),
195 SyntaxKind::LIST => Self::List(List(node)),
196 SyntaxKind::TUPLE => Self::Tuple(Tuple(node)),
197 SyntaxKind::SPREAD_EXPR => Self::Spread(SpreadExpr(node)),
198 SyntaxKind::COMPREHENSION => Self::Comprehension(Comprehension(node)),
199 SyntaxKind::BINARY_EXPR => Self::Binary(BinaryExpr(node)),
200 SyntaxKind::UNARY_EXPR => Self::Unary(UnaryExpr(node)),
201 SyntaxKind::TERNARY_EXPR => Self::Ternary(TernaryExpr(node)),
202 SyntaxKind::CALL_EXPR => Self::Call(CallExpr(node)),
203 SyntaxKind::F_STRING => Self::FString(FString(node)),
204 SyntaxKind::TYPE_NODE => Self::Type(TypeNode(node)),
205 SyntaxKind::WILDCARD => Self::Wildcard(Wildcard(node)),
206 SyntaxKind::WHERE_EXPR => Self::Where(WhereExpr(node)),
207 SyntaxKind::MATCH_EXPR => Self::Match(MatchExpr(node)),
208 SyntaxKind::CLOSURE => Self::Closure(Closure(node)),
209 SyntaxKind::VARIANT_CTOR => Self::VariantCtor(VariantCtor(node)),
210 SyntaxKind::ERROR => Self::Error(ErrorNode(node)),
211 _ => return None,
212 })
213 }
214
215 pub fn syntax(&self) -> &SyntaxNode {
217 match self {
218 Self::Literal(n) => n.syntax(),
219 Self::Variable(n) => n.syntax(),
220 Self::Reference(n) => n.syntax(),
221 Self::Dict(n) => n.syntax(),
222 Self::List(n) => n.syntax(),
223 Self::Tuple(n) => n.syntax(),
224 Self::Spread(n) => n.syntax(),
225 Self::Comprehension(n) => n.syntax(),
226 Self::Binary(n) => n.syntax(),
227 Self::Unary(n) => n.syntax(),
228 Self::Ternary(n) => n.syntax(),
229 Self::Call(n) => n.syntax(),
230 Self::FString(n) => n.syntax(),
231 Self::Type(n) => n.syntax(),
232 Self::Wildcard(n) => n.syntax(),
233 Self::Where(n) => n.syntax(),
234 Self::Match(n) => n.syntax(),
235 Self::Closure(n) => n.syntax(),
236 Self::VariantCtor(n) => n.syntax(),
237 Self::Error(n) => n.syntax(),
238 }
239 }
240
241 pub fn text(&self) -> String {
243 self.syntax().text().to_string()
244 }
245}
246
247impl Document {
254 pub fn directives(&self) -> impl Iterator<Item = Directive> + '_ {
256 self.0.children().filter_map(Directive::cast)
257 }
258
259 pub fn decorators(&self) -> impl Iterator<Item = Decorator> + '_ {
261 self.0.children().filter_map(Decorator::cast)
262 }
263
264 pub fn root_expr(&self) -> Option<Expr> {
267 self.0.children().find_map(Expr::cast)
268 }
269}
270
271impl Directive {
272 pub fn name(&self) -> Option<String> {
275 self.0
278 .children_with_tokens()
279 .filter_map(|el| el.into_token())
280 .find(|t| t.kind() == SyntaxKind::IDENT)
281 .map(|t| t.text().to_string())
282 }
283
284 pub fn body_exprs(&self) -> impl Iterator<Item = Expr> + '_ {
290 self.0.children().filter_map(Expr::cast)
291 }
292}
293
294impl Decorator {
295 pub fn name(&self) -> Option<String> {
296 self.0
297 .children_with_tokens()
298 .filter_map(|el| el.into_token())
299 .find(|t| t.kind() == SyntaxKind::IDENT)
300 .map(|t| t.text().to_string())
301 }
302
303 pub fn args(&self) -> impl Iterator<Item = Expr> + '_ {
304 self.0
307 .children()
308 .find(|c| c.kind() == SyntaxKind::CALL_ARG)
309 .into_iter()
310 .flat_map(|n| n.children().filter_map(Expr::cast).collect::<Vec<_>>())
311 }
312}
313
314impl Dict {
315 pub fn fields(&self) -> impl Iterator<Item = DictField> + '_ {
316 self.0.children().filter_map(DictField::cast)
317 }
318}
319
320impl DictField {
321 pub fn key_text(&self) -> Option<String> {
325 self.0
326 .children_with_tokens()
327 .filter_map(|el| el.into_token())
328 .find(|t| t.kind() == SyntaxKind::IDENT || t.kind() == SyntaxKind::STRING)
329 .map(|t| t.text().to_string())
330 }
331
332 pub fn value(&self) -> Option<Expr> {
336 self.0.children().filter_map(Expr::cast).next()
337 }
338}
339
340impl List {
341 pub fn items(&self) -> impl Iterator<Item = Expr> + '_ {
342 self.0.children().filter_map(Expr::cast)
343 }
344}
345
346impl Tuple {
347 pub fn items(&self) -> impl Iterator<Item = Expr> + '_ {
349 self.0.children().filter_map(Expr::cast)
350 }
351}
352
353impl Comprehension {
354 pub fn parts(&self) -> Vec<Expr> {
358 self.0.children().filter_map(Expr::cast).collect()
359 }
360
361 pub fn binding(&self) -> Option<String> {
364 let mut after_for = false;
365 for el in self.0.children_with_tokens() {
366 if let Some(t) = el.as_token() {
367 if t.kind() == SyntaxKind::IDENT {
368 let s = t.text();
369 if after_for {
370 return Some(s.to_string());
371 }
372 if s == "for" {
373 after_for = true;
374 }
375 }
376 }
377 }
378 None
379 }
380}
381
382impl Closure {
383 pub fn params(&self) -> impl Iterator<Item = ClosureParam> + '_ {
384 self.0.children().filter_map(ClosureParam::cast)
385 }
386
387 pub fn return_type(&self) -> Option<TypeNode> {
390 let mut saw_arrow = false;
391 for el in self.0.children_with_tokens() {
392 if let Some(t) = el.as_token() {
393 if t.kind() == SyntaxKind::THIN_ARROW {
394 saw_arrow = true;
395 }
396 } else if let Some(n) = el.as_node() {
397 if saw_arrow && n.kind() == SyntaxKind::TYPE_NODE {
398 return TypeNode::cast(n.clone());
399 }
400 }
401 }
402 None
403 }
404
405 pub fn body(&self) -> Option<Expr> {
408 let mut last: Option<Expr> = None;
413 for child in self.0.children() {
414 if child.kind() == SyntaxKind::CLOSURE_PARAM || child.kind() == SyntaxKind::TYPE_NODE {
415 continue;
416 }
417 if let Some(e) = Expr::cast(child) {
418 last = Some(e);
419 }
420 }
421 last
422 }
423}
424
425impl ClosureParam {
426 pub fn name(&self) -> Option<String> {
427 self.0
430 .children_with_tokens()
431 .filter_map(|el| el.into_token())
432 .filter(|t| t.kind() == SyntaxKind::IDENT)
433 .last()
434 .map(|t| t.text().to_string())
435 }
436
437 pub fn type_hint(&self) -> Option<TypeNode> {
438 self.0.children().find_map(TypeNode::cast)
439 }
440}
441
442impl CallExpr {
443 pub fn callee(&self) -> Option<Expr> {
447 self.0.children().find_map(Expr::cast)
448 }
449
450 pub fn args(&self) -> impl Iterator<Item = Expr> + '_ {
452 self.0
453 .children()
454 .find(|c| c.kind() == SyntaxKind::CALL_ARG)
455 .into_iter()
456 .flat_map(|n| n.children().filter_map(Expr::cast).collect::<Vec<_>>())
457 }
458}
459
460impl BinaryExpr {
461 pub fn op_kind(&self) -> Option<SyntaxKind> {
465 self.0
466 .children_with_tokens()
467 .filter_map(|el| el.into_token())
468 .map(|t| t.kind())
469 .find(|k| {
470 matches!(
471 k,
472 SyntaxKind::PLUS
473 | SyntaxKind::MINUS
474 | SyntaxKind::STAR
475 | SyntaxKind::SLASH
476 | SyntaxKind::PERCENT
477 | SyntaxKind::PLUS_PLUS
478 | SyntaxKind::EQ_EQ
479 | SyntaxKind::BANG_EQ
480 | SyntaxKind::LT
481 | SyntaxKind::GT
482 | SyntaxKind::LT_EQ
483 | SyntaxKind::GT_EQ
484 | SyntaxKind::AMP_AMP
485 | SyntaxKind::PIPE_PIPE
486 | SyntaxKind::PIPE
487 )
488 })
489 }
490
491 pub fn lhs(&self) -> Option<Expr> {
492 self.0.children().find_map(Expr::cast)
493 }
494
495 pub fn rhs(&self) -> Option<Expr> {
496 self.0.children().filter_map(Expr::cast).nth(1)
497 }
498}
499
500impl UnaryExpr {
501 pub fn op_kind(&self) -> Option<SyntaxKind> {
503 self.0
504 .children_with_tokens()
505 .filter_map(|el| el.into_token())
506 .map(|t| t.kind())
507 .find(|k| matches!(k, SyntaxKind::MINUS | SyntaxKind::BANG | SyntaxKind::PLUS))
508 }
509
510 pub fn operand(&self) -> Option<Expr> {
511 self.0.children().find_map(Expr::cast)
512 }
513}
514
515impl TernaryExpr {
516 pub fn cond(&self) -> Option<Expr> {
517 self.0.children().find_map(Expr::cast)
518 }
519
520 pub fn then(&self) -> Option<Expr> {
521 self.0.children().filter_map(Expr::cast).nth(1)
522 }
523
524 pub fn els(&self) -> Option<Expr> {
525 self.0.children().filter_map(Expr::cast).nth(2)
526 }
527}
528
529impl ReferenceExpr {
530 pub fn base_name(&self) -> Option<String> {
534 self.0
535 .children_with_tokens()
536 .filter_map(|el| el.into_token())
537 .find(|t| t.kind() == SyntaxKind::IDENT)
538 .map(|t| t.text().to_string())
539 }
540
541 pub fn path_text(&self) -> String {
544 self.text()
545 }
546}
547
548impl VariableExpr {
549 pub fn segments(&self) -> Vec<String> {
551 self.0
552 .children_with_tokens()
553 .filter_map(|el| el.into_token())
554 .filter(|t| t.kind() == SyntaxKind::IDENT)
555 .map(|t| t.text().to_string())
556 .collect()
557 }
558}
559
560impl Literal {
561 pub fn kind(&self) -> Option<SyntaxKind> {
565 self.0
566 .children_with_tokens()
567 .filter_map(|el| el.into_token())
568 .map(|t| t.kind())
569 .find(|k| {
570 matches!(
571 k,
572 SyntaxKind::NUMBER | SyntaxKind::STRING | SyntaxKind::IDENT
573 )
574 })
575 }
576
577 pub fn value_text(&self) -> String {
580 self.text()
581 }
582}
583
584impl WhereExpr {
585 pub fn expr(&self) -> Option<Expr> {
587 self.0.children().find_map(Expr::cast)
588 }
589
590 pub fn bindings(&self) -> Option<Dict> {
592 self.0.children().filter_map(Dict::cast).next()
593 }
594}
595
596impl MatchExpr {
597 pub fn scrutinee(&self) -> Option<Expr> {
599 self.0.children().find_map(Expr::cast)
600 }
601
602 pub fn arms(&self) -> impl Iterator<Item = MatchArm> + '_ {
603 self.0.children().filter_map(MatchArm::cast)
604 }
605}
606
607impl MatchArm {
608 pub fn pattern(&self) -> Option<Expr> {
611 self.0.children().find_map(Expr::cast)
612 }
613
614 pub fn body(&self) -> Option<Expr> {
616 self.0.children().filter_map(Expr::cast).nth(1)
617 }
618}
619
620impl SpreadExpr {
621 pub fn inner(&self) -> Option<Expr> {
623 self.0.children().find_map(Expr::cast)
624 }
625}
626
627impl VariantCtor {
628 pub fn body(&self) -> Option<Dict> {
630 self.0.children().find_map(Dict::cast)
631 }
632}
633
634impl FString {
635 pub fn parts(&self) -> Vec<FStringPart> {
638 let mut out = Vec::new();
639 for el in self.0.children_with_tokens() {
640 if let Some(t) = el.as_token() {
641 if t.kind() == SyntaxKind::F_STRING_LITERAL {
642 out.push(FStringPart::Literal(t.text().to_string()));
643 }
644 } else if let Some(n) = el.as_node() {
645 if let Some(interp) = FStringInterpolation::cast(n.clone()) {
646 out.push(FStringPart::Interpolation(interp));
647 }
648 }
649 }
650 out
651 }
652}
653
654impl FStringInterpolation {
655 pub fn expr(&self) -> Option<Expr> {
657 self.0.children().find_map(Expr::cast)
658 }
659}
660
/// One piece of an f-string, as produced by `FString::parts`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FStringPart {
    /// The text of an `F_STRING_LITERAL` token.
    Literal(String),
    /// An `F_STRING_INTERPOLATION` node.
    Interpolation(FStringInterpolation),
}
668
669impl TypeNode {
670 pub fn path_text(&self) -> Vec<String> {
674 let mut out = Vec::new();
675 for el in self.0.children_with_tokens() {
676 if let Some(t) = el.as_token() {
677 match t.kind() {
678 SyntaxKind::LT => break,
679 SyntaxKind::QUESTION => break,
680 SyntaxKind::DOT => continue,
681 SyntaxKind::IDENT | SyntaxKind::STRING => out.push(t.text().to_string()),
682 _ => {}
683 }
684 } else {
685 break;
686 }
687 }
688 out
689 }
690
691 pub fn generics(&self) -> impl Iterator<Item = TypeNode> + '_ {
694 self.0.children().filter_map(TypeNode::cast)
695 }
696
697 pub fn is_optional(&self) -> bool {
699 self.0
700 .children_with_tokens()
701 .filter_map(|el| el.into_token())
702 .any(|t| t.kind() == SyntaxKind::QUESTION)
703 }
704}
705
/// Wraps `syntax` as a [`Document`].
///
/// Delegates to `Document::cast`, so the result is `None` unless the node's
/// kind is `SyntaxKind::DOCUMENT`.
pub fn document_of(syntax: SyntaxNode) -> Option<Document> {
    Document::cast(syntax)
}
717
718pub use crate::syntax::SyntaxToken as _Token;
721
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cst::parse_cst;

    /// Parses `src` and wraps the resulting root node as a [`Document`].
    fn parse_document(src: &str) -> Document {
        let parsed = parse_cst(src);
        Document::cast(parsed.syntax()).expect("parse_cst root should be a DOCUMENT node")
    }

    /// Parses `src` and returns its root expression, which must be a dict.
    fn root_dict(src: &str) -> Dict {
        match parse_document(src).root_expr() {
            Some(Expr::Dict(dict)) => dict,
            other => panic!("expected a dict root expression, got {other:?}"),
        }
    }

    /// Parses `src` and returns the value of the first field of its root dict.
    fn first_field_value(src: &str) -> Expr {
        let dict = root_dict(src);
        let value = dict.fields().next().and_then(|field| field.value());
        value.expect("first dict field should have a value")
    }

    #[test]
    fn document_round_trip() {
        let doc = parse_document("{ a: 1, b: 2 }");
        assert!(doc.root_expr().is_some());
    }

    #[test]
    fn dict_fields() {
        let dict = root_dict("{ alice: 1, bob: 2 }");
        let keys: Vec<_> = dict.fields().filter_map(|f| f.key_text()).collect();
        assert_eq!(keys, vec!["alice".to_string(), "bob".to_string()]);
    }

    #[test]
    fn binary_op_kind() {
        let bin = match first_field_value("{ x: 1 + 2 }") {
            Expr::Binary(b) => b,
            other => panic!("not binary: {other:?}"),
        };
        assert_eq!(bin.op_kind(), Some(SyntaxKind::PLUS));
        assert!(bin.lhs().is_some());
        assert!(bin.rhs().is_some());
    }

    #[test]
    fn closure_typed_params() {
        let cls = match first_field_value("{ add(Int a, Int b): a + b }") {
            Expr::Closure(c) => c,
            other => panic!("not closure: {other:?}"),
        };
        let params: Vec<_> = cls.params().collect();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].name().as_deref(), Some("a"));
        assert!(params[0].type_hint().is_some());
    }

    #[test]
    fn f_string_parts() {
        let fs = match first_field_value(r#"{ msg: f"hi ${name}!" }"#) {
            Expr::FString(f) => f,
            other => panic!("not an f-string: {other:?}"),
        };
        let parts = fs.parts();
        // Checked separately so a failure says which kind of part is missing.
        assert!(
            parts.iter().any(|p| matches!(p, FStringPart::Literal(_))),
            "expected at least one literal part, got {parts:?}"
        );
        assert!(
            parts
                .iter()
                .any(|p| matches!(p, FStringPart::Interpolation(_))),
            "expected at least one interpolation part, got {parts:?}"
        );
    }

    #[test]
    fn directive_name() {
        let doc = parse_document("#schema X { Int a: * }\n{ x: 1 }");
        let dirs: Vec<_> = doc.directives().collect();
        assert_eq!(dirs.len(), 1);
        assert_eq!(dirs[0].name().as_deref(), Some("schema"));
    }

    #[test]
    fn match_arms() {
        let cls = match first_field_value("{ f(x): x match { Int: 1, _ : 0 } }") {
            Expr::Closure(c) => c,
            other => panic!("not closure: {other:?}"),
        };
        let m = match cls.body() {
            Some(Expr::Match(m)) => m,
            other => panic!("closure body is not a match expression: {other:?}"),
        };
        assert_eq!(m.arms().count(), 2);
    }

    #[test]
    fn error_variant_for_partial_parse() {
        let doc = parse_document("{ broken @ # }");
        let any_error = doc
            .syntax()
            .descendants()
            .filter_map(Expr::cast)
            .any(|e| matches!(e, Expr::Error(_)));
        assert!(any_error, "expected at least one Expr::Error variant");
    }
}