use crate::syntax::{SyntaxKind, SyntaxNode};
/// Declares a typed newtype over a `SyntaxNode` of one specific `SyntaxKind`.
///
/// Usage: `ast_node!(/// docs… TypeName, SYNTAX_KIND);` — any leading
/// attributes (including doc comments) are forwarded onto the generated struct.
macro_rules! ast_node {
    ($(#[$meta:meta])* $name:ident, $kind:ident) => {
        $(#[$meta])*
        #[derive(Debug, Clone, PartialEq, Eq, Hash)]
        pub struct $name(SyntaxNode);

        impl $name {
            /// Wraps `node` when its kind is the one this wrapper was declared
            /// for; any other kind yields `None`.
            pub fn cast(node: SyntaxNode) -> Option<Self> {
                match node.kind() {
                    SyntaxKind::$kind => Some(Self(node)),
                    _ => None,
                }
            }

            /// Borrows the underlying syntax node.
            pub fn syntax(&self) -> &SyntaxNode {
                &self.0
            }

            /// Returns the full source text covered by this node.
            pub fn text(&self) -> String {
                let node = self.syntax();
                node.text().to_string()
            }
        }
    };
}
ast_node!(
    /// Root of a parsed file: directives, decorators and the root expression.
    Document, DOCUMENT
);
ast_node!(
    /// A `#name …` directive such as `#schema`.
    Directive, DIRECTIVE
);
ast_node!(
    /// A named decorator with optional call arguments.
    Decorator, DECORATOR
);
ast_node!(
    /// A dict literal made of `DictField` entries.
    Dict, DICT
);
ast_node!(
    /// One `key: value` entry of a dict.
    DictField, DICT_FIELD
);
ast_node!(
    /// A list expression; its items are its expression children.
    List, LIST
);
ast_node!(
    /// A tuple expression; its items are its expression children.
    Tuple, TUPLE
);
ast_node!(
    /// A comprehension with a `for` binding.
    Comprehension, COMPREHENSION
);
ast_node!(
    /// A closure with parameters, an optional return type and a body.
    Closure, CLOSURE
);
ast_node!(
    /// A single closure parameter with an optional type hint.
    ClosureParam, CLOSURE_PARAM
);
ast_node!(
    /// A call: callee expression followed by argument nodes.
    CallExpr, CALL_EXPR
);
ast_node!(
    /// Node holding argument expressions of a call or decorator.
    CallArg, CALL_ARG
);
ast_node!(
    /// A binary operator expression such as `1 + 2`.
    BinaryExpr, BINARY_EXPR
);
ast_node!(
    /// A unary `-`, `!` or `+` expression.
    UnaryExpr, UNARY_EXPR
);
ast_node!(
    /// A three-operand conditional: condition, then-branch, else-branch.
    TernaryExpr, TERNARY_EXPR
);
ast_node!(
    /// A reference expression headed by a base identifier.
    ReferenceExpr, REFERENCE_EXPR
);
ast_node!(
    /// A variable expression made of identifier segments.
    VariableExpr, VARIABLE_EXPR
);
ast_node!(
    /// A subject expression paired with a bindings dict.
    WhereExpr, WHERE_EXPR
);
ast_node!(
    /// A `… match { … }` expression: scrutinee plus arms.
    MatchExpr, MATCH_EXPR
);
ast_node!(
    /// One `pattern: body` arm of a match expression.
    MatchArm, MATCH_ARM
);
ast_node!(
    /// A variant constructor whose payload is a dict.
    VariantCtor, VARIANT_CTOR
);
ast_node!(
    /// An interpolated `f"…"` string.
    FString, F_STRING
);
ast_node!(
    /// A `${…}` interpolation inside an f-string.
    FStringInterpolation, F_STRING_INTERPOLATION
);
ast_node!(
    /// A spread expression wrapping one inner expression.
    SpreadExpr, SPREAD_EXPR
);
ast_node!(
    /// A type reference, possibly generic and/or optional (`?`).
    TypeNode, TYPE_NODE
);
ast_node!(
    /// Wrapper for `TUPLE_TYPE` nodes (no accessors defined here).
    TupleType, TUPLE_TYPE
);
ast_node!(
    /// Wrapper for `SCHEMA_WITH` nodes (no accessors defined here).
    SchemaWith, SCHEMA_WITH
);
ast_node!(
    /// Wrapper for `SCHEMA_METHOD` nodes (no accessors defined here).
    SchemaMethod, SCHEMA_METHOD
);
ast_node!(
    /// The `_` wildcard.
    Wildcard, WILDCARD
);
ast_node!(
    /// A number, string or identifier literal.
    Literal, LITERAL
);
ast_node!(
    /// A node produced by parse error recovery.
    ErrorNode, ERROR
);
/// Any node that can appear in expression position.
///
/// Each variant wraps the typed node for one `SyntaxKind`; `Expr::cast` picks
/// the variant from the node's kind. Variant order is kept as-is because the
/// derived `Hash` depends on it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Expr {
/// `LITERAL` node.
Literal(Literal),
/// `VARIABLE_EXPR` node.
Variable(VariableExpr),
/// `REFERENCE_EXPR` node.
Reference(ReferenceExpr),
/// `DICT` node.
Dict(Dict),
/// `LIST` node.
List(List),
/// `TUPLE` node.
Tuple(Tuple),
/// `SPREAD_EXPR` node.
Spread(SpreadExpr),
/// `COMPREHENSION` node.
Comprehension(Comprehension),
/// `BINARY_EXPR` node.
Binary(BinaryExpr),
/// `UNARY_EXPR` node.
Unary(UnaryExpr),
/// `TERNARY_EXPR` node.
Ternary(TernaryExpr),
/// `CALL_EXPR` node.
Call(CallExpr),
/// `F_STRING` node.
FString(FString),
/// `TYPE_NODE` node (types may appear as expressions, e.g. match patterns).
Type(TypeNode),
/// `WILDCARD` node.
Wildcard(Wildcard),
/// `WHERE_EXPR` node.
Where(WhereExpr),
/// `MATCH_EXPR` node.
Match(MatchExpr),
/// `CLOSURE` node.
Closure(Closure),
/// `VARIANT_CTOR` node.
VariantCtor(VariantCtor),
/// `ERROR` node left by parse recovery.
Error(ErrorNode),
}
impl Expr {
pub fn cast(node: SyntaxNode) -> Option<Self> {
Some(match node.kind() {
SyntaxKind::LITERAL => Self::Literal(Literal(node)),
SyntaxKind::VARIABLE_EXPR => Self::Variable(VariableExpr(node)),
SyntaxKind::REFERENCE_EXPR => Self::Reference(ReferenceExpr(node)),
SyntaxKind::DICT => Self::Dict(Dict(node)),
SyntaxKind::LIST => Self::List(List(node)),
SyntaxKind::TUPLE => Self::Tuple(Tuple(node)),
SyntaxKind::SPREAD_EXPR => Self::Spread(SpreadExpr(node)),
SyntaxKind::COMPREHENSION => Self::Comprehension(Comprehension(node)),
SyntaxKind::BINARY_EXPR => Self::Binary(BinaryExpr(node)),
SyntaxKind::UNARY_EXPR => Self::Unary(UnaryExpr(node)),
SyntaxKind::TERNARY_EXPR => Self::Ternary(TernaryExpr(node)),
SyntaxKind::CALL_EXPR => Self::Call(CallExpr(node)),
SyntaxKind::F_STRING => Self::FString(FString(node)),
SyntaxKind::TYPE_NODE => Self::Type(TypeNode(node)),
SyntaxKind::WILDCARD => Self::Wildcard(Wildcard(node)),
SyntaxKind::WHERE_EXPR => Self::Where(WhereExpr(node)),
SyntaxKind::MATCH_EXPR => Self::Match(MatchExpr(node)),
SyntaxKind::CLOSURE => Self::Closure(Closure(node)),
SyntaxKind::VARIANT_CTOR => Self::VariantCtor(VariantCtor(node)),
SyntaxKind::ERROR => Self::Error(ErrorNode(node)),
_ => return None,
})
}
pub fn syntax(&self) -> &SyntaxNode {
match self {
Self::Literal(n) => n.syntax(),
Self::Variable(n) => n.syntax(),
Self::Reference(n) => n.syntax(),
Self::Dict(n) => n.syntax(),
Self::List(n) => n.syntax(),
Self::Tuple(n) => n.syntax(),
Self::Spread(n) => n.syntax(),
Self::Comprehension(n) => n.syntax(),
Self::Binary(n) => n.syntax(),
Self::Unary(n) => n.syntax(),
Self::Ternary(n) => n.syntax(),
Self::Call(n) => n.syntax(),
Self::FString(n) => n.syntax(),
Self::Type(n) => n.syntax(),
Self::Wildcard(n) => n.syntax(),
Self::Where(n) => n.syntax(),
Self::Match(n) => n.syntax(),
Self::Closure(n) => n.syntax(),
Self::VariantCtor(n) => n.syntax(),
Self::Error(n) => n.syntax(),
}
}
pub fn text(&self) -> String {
self.syntax().text().to_string()
}
}
impl Document {
    /// Iterates over the directive children of the document, in source order.
    pub fn directives(&self) -> impl Iterator<Item = Directive> + '_ {
        self.syntax().children().filter_map(Directive::cast)
    }

    /// Iterates over the decorator children of the document, in source order.
    pub fn decorators(&self) -> impl Iterator<Item = Decorator> + '_ {
        self.syntax().children().filter_map(Decorator::cast)
    }

    /// Returns the first child that is an expression, if any.
    pub fn root_expr(&self) -> Option<Expr> {
        self.syntax().children().filter_map(Expr::cast).next()
    }
}
impl Directive {
    /// Returns the text of the first identifier token directly under this node
    /// (e.g. `schema` for `#schema X { … }`).
    pub fn name(&self) -> Option<String> {
        self.0
            .children_with_tokens()
            .find_map(|el| el.into_token().filter(|t| t.kind() == SyntaxKind::IDENT))
            .map(|t| t.text().to_string())
    }

    /// Iterates over the expression children of the directive, in source order.
    pub fn body_exprs(&self) -> impl Iterator<Item = Expr> + '_ {
        self.syntax().children().filter_map(Expr::cast)
    }
}
impl Decorator {
    /// Returns the text of the first identifier token directly under this node.
    pub fn name(&self) -> Option<String> {
        self.0
            .children_with_tokens()
            .filter_map(|el| el.into_token())
            .find(|t| t.kind() == SyntaxKind::IDENT)
            .map(|t| t.text().to_string())
    }

    /// Returns the decorator's argument expressions, in source order.
    ///
    /// Argument expressions live inside `CALL_ARG` child nodes. Every such
    /// child is visited, not only the first, so no argument is dropped when
    /// the parser emits one `CALL_ARG` per argument. The result is unchanged
    /// when a single `CALL_ARG` wraps the whole list.
    pub fn args(&self) -> impl Iterator<Item = Expr> + '_ {
        self.0
            .children()
            .filter(|c| c.kind() == SyntaxKind::CALL_ARG)
            .flat_map(|n| n.children().filter_map(Expr::cast).collect::<Vec<_>>())
    }
}
impl Dict {
    /// Iterates over the `key: value` entries of the dict, in source order.
    pub fn fields(&self) -> impl Iterator<Item = DictField> + '_ {
        self.syntax().children().filter_map(DictField::cast)
    }
}
impl DictField {
    /// Returns the text of the first identifier or string token directly under
    /// this entry. String keys are returned verbatim, including any quotes.
    pub fn key_text(&self) -> Option<String> {
        self.0.children_with_tokens().find_map(|el| {
            let tok = el.into_token()?;
            let is_key = tok.kind() == SyntaxKind::IDENT || tok.kind() == SyntaxKind::STRING;
            is_key.then(|| tok.text().to_string())
        })
    }

    /// Returns the first expression child of the entry, taken as its value.
    pub fn value(&self) -> Option<Expr> {
        self.0.children().find_map(Expr::cast)
    }
}
impl List {
    /// Iterates over the list's expression children, in source order.
    pub fn items(&self) -> impl Iterator<Item = Expr> + '_ {
        self.syntax().children().filter_map(Expr::cast)
    }
}
impl Tuple {
    /// Iterates over the tuple's expression children, in source order.
    pub fn items(&self) -> impl Iterator<Item = Expr> + '_ {
        self.syntax().children().filter_map(Expr::cast)
    }
}
impl Comprehension {
    /// Collects every expression child of the comprehension, in source order.
    pub fn parts(&self) -> Vec<Expr> {
        self.syntax()
            .children()
            .filter_map(Expr::cast)
            .collect::<Vec<_>>()
    }

    /// Returns the first identifier that follows the first `for` identifier
    /// token directly under this node.
    ///
    /// The `for` keyword is recognised as an `IDENT` token whose text is `for`.
    pub fn binding(&self) -> Option<String> {
        let mut idents = self
            .0
            .children_with_tokens()
            .filter_map(|el| el.into_token())
            .filter(|t| t.kind() == SyntaxKind::IDENT);
        // Consume identifiers up to and including the `for` keyword.
        idents.find(|t| t.text() == "for")?;
        idents.next().map(|t| t.text().to_string())
    }
}
impl Closure {
    /// Iterates over the closure's parameter nodes, in source order.
    pub fn params(&self) -> impl Iterator<Item = ClosureParam> + '_ {
        self.syntax().children().filter_map(ClosureParam::cast)
    }

    /// Returns the first type node that follows the `->` token, if any.
    pub fn return_type(&self) -> Option<TypeNode> {
        let mut elements = self.0.children_with_tokens();
        // Skip everything up to and including the `->`; no arrow means no return type.
        elements.find(|el| {
            el.as_token()
                .map_or(false, |t| t.kind() == SyntaxKind::THIN_ARROW)
        })?;
        elements.find_map(|el| el.as_node().cloned().and_then(TypeNode::cast))
    }

    /// Returns the last expression child that is neither a parameter nor a
    /// type node, which is taken as the closure body.
    pub fn body(&self) -> Option<Expr> {
        self.0
            .children()
            .filter(|c| {
                !(c.kind() == SyntaxKind::CLOSURE_PARAM || c.kind() == SyntaxKind::TYPE_NODE)
            })
            .filter_map(Expr::cast)
            .last()
    }
}
impl ClosureParam {
    /// Returns the text of the last identifier token directly under the
    /// parameter (e.g. `a` in `Int a`).
    pub fn name(&self) -> Option<String> {
        self.0
            .children_with_tokens()
            .filter_map(|el| el.into_token().filter(|t| t.kind() == SyntaxKind::IDENT))
            .last()
            .map(|t| t.text().to_string())
    }

    /// Returns the first type-node child, i.e. the parameter's type hint.
    pub fn type_hint(&self) -> Option<TypeNode> {
        self.syntax().children().filter_map(TypeNode::cast).next()
    }
}
impl CallExpr {
    /// Returns the callee: the first expression child of the call.
    pub fn callee(&self) -> Option<Expr> {
        self.0.children().find_map(Expr::cast)
    }

    /// Returns the call's argument expressions, in source order.
    ///
    /// Argument expressions live inside `CALL_ARG` child nodes. Every such
    /// child is visited, not only the first, so no argument is dropped when
    /// the parser emits one `CALL_ARG` per argument. The result is unchanged
    /// when a single `CALL_ARG` wraps the whole list.
    pub fn args(&self) -> impl Iterator<Item = Expr> + '_ {
        self.0
            .children()
            .filter(|c| c.kind() == SyntaxKind::CALL_ARG)
            .flat_map(|n| n.children().filter_map(Expr::cast).collect::<Vec<_>>())
    }
}
impl BinaryExpr {
    /// Returns the kind of the first binary-operator token directly under this
    /// node, or `None` if no such token exists.
    pub fn op_kind(&self) -> Option<SyntaxKind> {
        self.0.children_with_tokens().find_map(|el| {
            let kind = el.into_token()?.kind();
            if Self::is_operator(&kind) {
                Some(kind)
            } else {
                None
            }
        })
    }

    /// Token kinds accepted as binary operators by `op_kind`.
    fn is_operator(kind: &SyntaxKind) -> bool {
        matches!(
            kind,
            SyntaxKind::PLUS
                | SyntaxKind::MINUS
                | SyntaxKind::STAR
                | SyntaxKind::SLASH
                | SyntaxKind::PERCENT
                | SyntaxKind::PLUS_PLUS
                | SyntaxKind::EQ_EQ
                | SyntaxKind::BANG_EQ
                | SyntaxKind::LT
                | SyntaxKind::GT
                | SyntaxKind::LT_EQ
                | SyntaxKind::GT_EQ
                | SyntaxKind::AMP_AMP
                | SyntaxKind::PIPE_PIPE
                | SyntaxKind::PIPE
        )
    }

    /// Left operand: the first expression child.
    pub fn lhs(&self) -> Option<Expr> {
        self.nth_expr(0)
    }

    /// Right operand: the second expression child.
    pub fn rhs(&self) -> Option<Expr> {
        self.nth_expr(1)
    }

    /// Returns the `index`-th (zero-based) expression child.
    fn nth_expr(&self, index: usize) -> Option<Expr> {
        self.0.children().filter_map(Expr::cast).nth(index)
    }
}
impl UnaryExpr {
    /// Returns the kind of the first `-`, `!` or `+` token directly under this
    /// node, or `None` if there is none.
    pub fn op_kind(&self) -> Option<SyntaxKind> {
        self.0.children_with_tokens().find_map(|el| {
            let kind = el.into_token()?.kind();
            if matches!(kind, SyntaxKind::MINUS | SyntaxKind::BANG | SyntaxKind::PLUS) {
                Some(kind)
            } else {
                None
            }
        })
    }

    /// Returns the operand: the first expression child.
    pub fn operand(&self) -> Option<Expr> {
        self.syntax().children().filter_map(Expr::cast).next()
    }
}
impl TernaryExpr {
    /// Condition: the first expression child.
    pub fn cond(&self) -> Option<Expr> {
        self.nth_expr(0)
    }

    /// Then-branch: the second expression child.
    pub fn then(&self) -> Option<Expr> {
        self.nth_expr(1)
    }

    /// Else-branch: the third expression child.
    pub fn els(&self) -> Option<Expr> {
        self.nth_expr(2)
    }

    /// Returns the `index`-th (zero-based) expression child.
    fn nth_expr(&self, index: usize) -> Option<Expr> {
        self.0.children().filter_map(Expr::cast).nth(index)
    }
}
impl ReferenceExpr {
    /// Returns the text of the first identifier token directly under the node.
    pub fn base_name(&self) -> Option<String> {
        self.0.children_with_tokens().find_map(|el| {
            let tok = el.into_token()?;
            (tok.kind() == SyntaxKind::IDENT).then(|| tok.text().to_string())
        })
    }

    /// Returns the full source text of the reference.
    pub fn path_text(&self) -> String {
        self.syntax().text().to_string()
    }
}
impl VariableExpr {
    /// Collects the text of every identifier token directly under the node,
    /// in source order.
    pub fn segments(&self) -> Vec<String> {
        let mut names = Vec::new();
        for el in self.0.children_with_tokens() {
            if let Some(tok) = el.as_token() {
                if tok.kind() == SyntaxKind::IDENT {
                    names.push(tok.text().to_string());
                }
            }
        }
        names
    }
}
impl Literal {
    /// Returns the kind of the first `NUMBER`, `STRING` or `IDENT` token
    /// directly under the literal, or `None` if there is none.
    pub fn kind(&self) -> Option<SyntaxKind> {
        for el in self.0.children_with_tokens() {
            if let Some(tok) = el.as_token() {
                let kind = tok.kind();
                if matches!(kind, SyntaxKind::NUMBER | SyntaxKind::STRING | SyntaxKind::IDENT) {
                    return Some(kind);
                }
            }
        }
        None
    }

    /// Returns the literal's full source text, unprocessed.
    pub fn value_text(&self) -> String {
        self.syntax().text().to_string()
    }
}
impl WhereExpr {
    /// Returns the subject expression: the first expression child.
    pub fn expr(&self) -> Option<Expr> {
        self.0.children().find_map(Expr::cast)
    }

    /// Returns the bindings dict: the first dict that follows the subject.
    ///
    /// The subject (the first expression child, as returned by `expr`) is
    /// skipped before searching. This keeps a subject that is itself a dict
    /// literal from being mistaken for the bindings. Returns `None` when no
    /// dict follows the subject.
    pub fn bindings(&self) -> Option<Dict> {
        self.0
            .children()
            .filter_map(Expr::cast)
            .skip(1)
            .find_map(|e| match e {
                Expr::Dict(d) => Some(d),
                _ => None,
            })
    }
}
impl MatchExpr {
    /// Returns the scrutinee: the first expression child.
    pub fn scrutinee(&self) -> Option<Expr> {
        self.syntax().children().filter_map(Expr::cast).next()
    }

    /// Iterates over the match arms, in source order.
    pub fn arms(&self) -> impl Iterator<Item = MatchArm> + '_ {
        self.syntax().children().filter_map(MatchArm::cast)
    }
}
impl MatchArm {
    /// Pattern: the first expression child of the arm.
    pub fn pattern(&self) -> Option<Expr> {
        self.nth_expr(0)
    }

    /// Body: the second expression child of the arm.
    pub fn body(&self) -> Option<Expr> {
        self.nth_expr(1)
    }

    /// Returns the `index`-th (zero-based) expression child.
    fn nth_expr(&self, index: usize) -> Option<Expr> {
        self.0.children().filter_map(Expr::cast).nth(index)
    }
}
impl SpreadExpr {
    /// Returns the spread operand: the first expression child.
    pub fn inner(&self) -> Option<Expr> {
        self.syntax().children().filter_map(Expr::cast).next()
    }
}
impl VariantCtor {
    /// Returns the constructor's payload: the first dict child.
    pub fn body(&self) -> Option<Dict> {
        self.syntax().children().filter_map(Dict::cast).next()
    }
}
impl FString {
    /// Splits the f-string into its pieces, in source order.
    ///
    /// `F_STRING_LITERAL` tokens become `FStringPart::Literal` and interpolation
    /// child nodes become `FStringPart::Interpolation`. All other children are
    /// ignored.
    pub fn parts(&self) -> Vec<FStringPart> {
        self.0
            .children_with_tokens()
            .filter_map(|el| {
                if let Some(tok) = el.as_token() {
                    if tok.kind() == SyntaxKind::F_STRING_LITERAL {
                        Some(FStringPart::Literal(tok.text().to_string()))
                    } else {
                        None
                    }
                } else {
                    el.as_node()
                        .cloned()
                        .and_then(FStringInterpolation::cast)
                        .map(FStringPart::Interpolation)
                }
            })
            .collect()
    }
}
impl FStringInterpolation {
    /// Returns the interpolated expression: the first expression child.
    pub fn expr(&self) -> Option<Expr> {
        self.syntax().children().filter_map(Expr::cast).next()
    }
}
/// One piece of an f-string, as produced by `FString::parts`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FStringPart {
/// Raw text of an `F_STRING_LITERAL` token.
Literal(String),
/// An interpolation node such as `${name}`.
Interpolation(FStringInterpolation),
}
impl TypeNode {
    /// Collects the identifier and string tokens of the dotted type name.
    ///
    /// Scanning stops at the first `<` or `?` token or at the first child
    /// node. `.` separators and any other tokens are skipped.
    pub fn path_text(&self) -> Vec<String> {
        let mut segments = Vec::new();
        for el in self.0.children_with_tokens() {
            let tok = match el.as_token() {
                Some(t) => t,
                None => break,
            };
            match tok.kind() {
                SyntaxKind::LT | SyntaxKind::QUESTION => break,
                SyntaxKind::IDENT | SyntaxKind::STRING => segments.push(tok.text().to_string()),
                _ => {}
            }
        }
        segments
    }

    /// Iterates over the type-node children (the generic arguments), in order.
    pub fn generics(&self) -> impl Iterator<Item = TypeNode> + '_ {
        self.syntax().children().filter_map(TypeNode::cast)
    }

    /// Returns `true` if any `?` token appears directly under this node.
    pub fn is_optional(&self) -> bool {
        self.0.children_with_tokens().any(|el| {
            el.as_token()
                .map_or(false, |t| t.kind() == SyntaxKind::QUESTION)
        })
    }
}
/// Views `syntax` as a `Document`, or returns `None` if it is not a
/// `DOCUMENT` node.
pub fn document_of(syntax: SyntaxNode) -> Option<Document> {
    Some(syntax).and_then(Document::cast)
}
pub use crate::syntax::SyntaxToken as _Token;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cst::parse_cst;

    /// Parses `src` and views the root as a `Document`.
    fn parse_doc(src: &str) -> Document {
        Document::cast(parse_cst(src).syntax()).unwrap()
    }

    /// Returns the dict at the root of `doc`; panics on any other root.
    fn root_dict(doc: &Document) -> Dict {
        match doc.root_expr().unwrap() {
            Expr::Dict(d) => d,
            _ => panic!(),
        }
    }

    /// Returns the value expression of the first field of `dict`.
    fn first_value(dict: &Dict) -> Expr {
        dict.fields().next().and_then(|f| f.value()).unwrap()
    }

    #[test]
    fn document_round_trip() {
        let parsed = parse_cst("{ a: 1, b: 2 }");
        let doc = Document::cast(parsed.syntax()).expect("DOCUMENT kind");
        assert!(doc.root_expr().is_some());
    }

    #[test]
    fn dict_fields() {
        let dict = root_dict(&parse_doc("{ alice: 1, bob: 2 }"));
        let keys = dict
            .fields()
            .filter_map(|f| f.key_text())
            .collect::<Vec<_>>();
        assert_eq!(keys, vec!["alice".to_string(), "bob".to_string()]);
    }

    #[test]
    fn binary_op_kind() {
        let dict = root_dict(&parse_doc("{ x: 1 + 2 }"));
        let bin = match first_value(&dict) {
            Expr::Binary(b) => b,
            other => panic!("not binary: {other:?}"),
        };
        assert_eq!(bin.op_kind(), Some(SyntaxKind::PLUS));
        assert!(bin.lhs().is_some());
        assert!(bin.rhs().is_some());
    }

    #[test]
    fn closure_typed_params() {
        let dict = root_dict(&parse_doc("{ add(Int a, Int b): a + b }"));
        let cls = match first_value(&dict) {
            Expr::Closure(c) => c,
            other => panic!("not closure: {other:?}"),
        };
        let params = cls.params().collect::<Vec<_>>();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].name().as_deref(), Some("a"));
        assert!(params[0].type_hint().is_some());
    }

    #[test]
    fn f_string_parts() {
        let dict = root_dict(&parse_doc(r#"{ msg: f"hi ${name}!" }"#));
        let fs = match first_value(&dict) {
            Expr::FString(f) => f,
            _ => panic!(),
        };
        let parts = fs.parts();
        let has_lit = parts.iter().any(|p| matches!(p, FStringPart::Literal(_)));
        let has_interp = parts
            .iter()
            .any(|p| matches!(p, FStringPart::Interpolation(_)));
        assert!(has_lit && has_interp);
    }

    #[test]
    fn directive_name() {
        let doc = parse_doc("#schema X { Int a: * }\n{ x: 1 }");
        let dirs = doc.directives().collect::<Vec<_>>();
        assert_eq!(dirs.len(), 1);
        assert_eq!(dirs[0].name().as_deref(), Some("schema"));
    }

    #[test]
    fn match_arms() {
        let dict = root_dict(&parse_doc("{ f(x): x match { Int: 1, _ : 0 } }"));
        let cls = match first_value(&dict) {
            Expr::Closure(c) => c,
            _ => panic!(),
        };
        let m = match cls.body().unwrap() {
            Expr::Match(m) => m,
            _ => panic!(),
        };
        assert_eq!(m.arms().count(), 2);
    }

    #[test]
    fn error_variant_for_partial_parse() {
        let doc = parse_doc("{ broken @ # }");
        let any_error = doc
            .syntax()
            .descendants()
            .filter_map(Expr::cast)
            .any(|e| matches!(e, Expr::Error(_)));
        assert!(any_error, "expected at least one Expr::Error variant");
    }
}