use std::num::NonZeroUsize;
use std::ops::Deref;
use ecow::EcoString;
use unscanny::Scanner;
use crate::{is_newline, Span, SyntaxKind, SyntaxNode};
/// A typed view over an untyped [`SyntaxNode`].
pub trait AstNode<'a>: Sized {
    /// Converts an untyped node into this typed view, or returns `None` if the
    /// conversion does not apply to the node.
    fn from_untyped(node: &'a SyntaxNode) -> Option<Self>;

    /// Returns the untyped node this typed view wraps.
    fn to_untyped(self) -> &'a SyntaxNode;

    /// Returns the span of the underlying untyped node.
    fn span(self) -> Span {
        let node = self.to_untyped();
        node.span()
    }
}
/// Declares a typed AST node `$name<'a>`: a transparent, `Copy` wrapper around a
/// `&'a SyntaxNode` whose kind is `SyntaxKind::$name`.
///
/// Attributes written before the name (including `///` doc comments, which are
/// `#[doc]` attributes) are forwarded to the generated struct.
macro_rules! node {
($(#[$attr:meta])* $name:ident) => {
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
#[repr(transparent)]
$(#[$attr])*
pub struct $name<'a>(&'a SyntaxNode);
impl<'a> AstNode<'a> for $name<'a> {
// The cast succeeds only if the node's kind is exactly `SyntaxKind::$name`.
#[inline]
fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
if node.kind() == SyntaxKind::$name {
Some(Self(node))
} else {
// Written as `Option::None` because this module defines a `None` node
// type that shadows the prelude variant.
Option::None
}
}
#[inline]
fn to_untyped(self) -> &'a SyntaxNode {
self.0
}
}
impl Default for $name<'_> {
// Each expansion has its own static placeholder node of the matching kind.
// A reference to it is valid for any lifetime, and creating it allocates nothing.
#[inline]
fn default() -> Self {
static PLACEHOLDER: SyntaxNode
= SyntaxNode::placeholder(SyntaxKind::$name);
Self(&PLACEHOLDER)
}
}
};
}
node! {
    /// A markup node (kind `SyntaxKind::Markup`).
    Markup
}

impl<'a> Markup<'a> {
    /// Iterates over the expressions in this markup.
    ///
    /// A `Space` child that directly follows a child whose kind reports
    /// `is_stmt()` is dropped. All other children go through
    /// `Expr::cast_with_space`, so ordinary spaces are kept as `Expr::Space`.
    pub fn exprs(self) -> impl DoubleEndedIterator<Item = Expr<'a>> {
        let mut after_stmt = false;
        self.0
            .children()
            .filter(move |child| {
                let kind = child.kind();
                let drop_it = after_stmt && kind == SyntaxKind::Space;
                after_stmt = kind.is_stmt();
                !drop_it
            })
            .filter_map(Expr::cast_with_space)
    }
}
/// A typed expression: one variant per expression node kind, each wrapping the
/// corresponding typed node.
///
/// Some variant names differ from the wrapped type: for example, `List` wraps
/// `ListItem`, `Code` wraps `CodeBlock` and `Let` wraps `LetBinding`.
///
/// Only `Expr::cast_with_space` produces the `Space` variant. The plain
/// `AstNode::from_untyped` cast rejects `Space` nodes.
#[derive(Debug, Copy, Clone, Hash)]
pub enum Expr<'a> {
Text(Text<'a>),
Space(Space<'a>),
Linebreak(Linebreak<'a>),
Parbreak(Parbreak<'a>),
Escape(Escape<'a>),
Shorthand(Shorthand<'a>),
SmartQuote(SmartQuote<'a>),
Strong(Strong<'a>),
Emph(Emph<'a>),
Raw(Raw<'a>),
Link(Link<'a>),
Label(Label<'a>),
Ref(Ref<'a>),
Heading(Heading<'a>),
List(ListItem<'a>),
Enum(EnumItem<'a>),
Term(TermItem<'a>),
Equation(Equation<'a>),
Math(Math<'a>),
MathIdent(MathIdent<'a>),
MathShorthand(MathShorthand<'a>),
MathAlignPoint(MathAlignPoint<'a>),
MathDelimited(MathDelimited<'a>),
MathAttach(MathAttach<'a>),
MathPrimes(MathPrimes<'a>),
MathFrac(MathFrac<'a>),
MathRoot(MathRoot<'a>),
Ident(Ident<'a>),
None(None<'a>),
Auto(Auto<'a>),
Bool(Bool<'a>),
Int(Int<'a>),
Float(Float<'a>),
Numeric(Numeric<'a>),
Str(Str<'a>),
Code(CodeBlock<'a>),
Content(ContentBlock<'a>),
Parenthesized(Parenthesized<'a>),
Array(Array<'a>),
Dict(Dict<'a>),
Unary(Unary<'a>),
Binary(Binary<'a>),
FieldAccess(FieldAccess<'a>),
FuncCall(FuncCall<'a>),
Closure(Closure<'a>),
Let(LetBinding<'a>),
DestructAssign(DestructAssignment<'a>),
Set(SetRule<'a>),
Show(ShowRule<'a>),
Contextual(Contextual<'a>),
Conditional(Conditional<'a>),
While(WhileLoop<'a>),
For(ForLoop<'a>),
Import(ModuleImport<'a>),
Include(ModuleInclude<'a>),
Break(LoopBreak<'a>),
Continue(LoopContinue<'a>),
Return(FuncReturn<'a>),
}
impl<'a> Expr<'a> {
    /// Like [`AstNode::from_untyped`], but also accepts `Space` nodes, which
    /// the plain cast rejects.
    fn cast_with_space(node: &'a SyntaxNode) -> Option<Self> {
        if node.kind() == SyntaxKind::Space {
            node.cast().map(Self::Space)
        } else {
            Self::from_untyped(node)
        }
    }
}
/// Casting between untyped nodes and [`Expr`] by dispatching on the node kind.
impl<'a> AstNode<'a> for Expr<'a> {
/// Casts a node to the `Expr` variant that matches its kind.
///
/// Returns `Option::None` for kinds that are not expressions. This includes
/// `Space`, which only `Expr::cast_with_space` accepts.
fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
match node.kind() {
SyntaxKind::Linebreak => node.cast().map(Self::Linebreak),
SyntaxKind::Parbreak => node.cast().map(Self::Parbreak),
SyntaxKind::Text => node.cast().map(Self::Text),
SyntaxKind::Escape => node.cast().map(Self::Escape),
SyntaxKind::Shorthand => node.cast().map(Self::Shorthand),
SyntaxKind::SmartQuote => node.cast().map(Self::SmartQuote),
SyntaxKind::Strong => node.cast().map(Self::Strong),
SyntaxKind::Emph => node.cast().map(Self::Emph),
SyntaxKind::Raw => node.cast().map(Self::Raw),
SyntaxKind::Link => node.cast().map(Self::Link),
SyntaxKind::Label => node.cast().map(Self::Label),
SyntaxKind::Ref => node.cast().map(Self::Ref),
SyntaxKind::Heading => node.cast().map(Self::Heading),
SyntaxKind::ListItem => node.cast().map(Self::List),
SyntaxKind::EnumItem => node.cast().map(Self::Enum),
SyntaxKind::TermItem => node.cast().map(Self::Term),
SyntaxKind::Equation => node.cast().map(Self::Equation),
SyntaxKind::Math => node.cast().map(Self::Math),
SyntaxKind::MathIdent => node.cast().map(Self::MathIdent),
SyntaxKind::MathShorthand => node.cast().map(Self::MathShorthand),
SyntaxKind::MathAlignPoint => node.cast().map(Self::MathAlignPoint),
SyntaxKind::MathDelimited => node.cast().map(Self::MathDelimited),
SyntaxKind::MathAttach => node.cast().map(Self::MathAttach),
SyntaxKind::MathPrimes => node.cast().map(Self::MathPrimes),
SyntaxKind::MathFrac => node.cast().map(Self::MathFrac),
SyntaxKind::MathRoot => node.cast().map(Self::MathRoot),
SyntaxKind::Ident => node.cast().map(Self::Ident),
SyntaxKind::None => node.cast().map(Self::None),
SyntaxKind::Auto => node.cast().map(Self::Auto),
SyntaxKind::Bool => node.cast().map(Self::Bool),
SyntaxKind::Int => node.cast().map(Self::Int),
SyntaxKind::Float => node.cast().map(Self::Float),
SyntaxKind::Numeric => node.cast().map(Self::Numeric),
SyntaxKind::Str => node.cast().map(Self::Str),
SyntaxKind::CodeBlock => node.cast().map(Self::Code),
SyntaxKind::ContentBlock => node.cast().map(Self::Content),
SyntaxKind::Parenthesized => node.cast().map(Self::Parenthesized),
SyntaxKind::Array => node.cast().map(Self::Array),
SyntaxKind::Dict => node.cast().map(Self::Dict),
SyntaxKind::Unary => node.cast().map(Self::Unary),
SyntaxKind::Binary => node.cast().map(Self::Binary),
SyntaxKind::FieldAccess => node.cast().map(Self::FieldAccess),
SyntaxKind::FuncCall => node.cast().map(Self::FuncCall),
SyntaxKind::Closure => node.cast().map(Self::Closure),
SyntaxKind::LetBinding => node.cast().map(Self::Let),
SyntaxKind::DestructAssignment => node.cast().map(Self::DestructAssign),
SyntaxKind::SetRule => node.cast().map(Self::Set),
SyntaxKind::ShowRule => node.cast().map(Self::Show),
SyntaxKind::Contextual => node.cast().map(Self::Contextual),
SyntaxKind::Conditional => node.cast().map(Self::Conditional),
SyntaxKind::WhileLoop => node.cast().map(Self::While),
SyntaxKind::ForLoop => node.cast().map(Self::For),
SyntaxKind::ModuleImport => node.cast().map(Self::Import),
SyntaxKind::ModuleInclude => node.cast().map(Self::Include),
SyntaxKind::LoopBreak => node.cast().map(Self::Break),
SyntaxKind::LoopContinue => node.cast().map(Self::Continue),
SyntaxKind::FuncReturn => node.cast().map(Self::Return),
// Any other kind (e.g. `Space`, punctuation, keywords) is not an expression.
_ => Option::None,
}
}
/// Returns the untyped node wrapped by whichever variant this is.
fn to_untyped(self) -> &'a SyntaxNode {
match self {
Self::Text(v) => v.to_untyped(),
Self::Space(v) => v.to_untyped(),
Self::Linebreak(v) => v.to_untyped(),
Self::Parbreak(v) => v.to_untyped(),
Self::Escape(v) => v.to_untyped(),
Self::Shorthand(v) => v.to_untyped(),
Self::SmartQuote(v) => v.to_untyped(),
Self::Strong(v) => v.to_untyped(),
Self::Emph(v) => v.to_untyped(),
Self::Raw(v) => v.to_untyped(),
Self::Link(v) => v.to_untyped(),
Self::Label(v) => v.to_untyped(),
Self::Ref(v) => v.to_untyped(),
Self::Heading(v) => v.to_untyped(),
Self::List(v) => v.to_untyped(),
Self::Enum(v) => v.to_untyped(),
Self::Term(v) => v.to_untyped(),
Self::Equation(v) => v.to_untyped(),
Self::Math(v) => v.to_untyped(),
Self::MathIdent(v) => v.to_untyped(),
Self::MathShorthand(v) => v.to_untyped(),
Self::MathAlignPoint(v) => v.to_untyped(),
Self::MathDelimited(v) => v.to_untyped(),
Self::MathAttach(v) => v.to_untyped(),
Self::MathPrimes(v) => v.to_untyped(),
Self::MathFrac(v) => v.to_untyped(),
Self::MathRoot(v) => v.to_untyped(),
Self::Ident(v) => v.to_untyped(),
Self::None(v) => v.to_untyped(),
Self::Auto(v) => v.to_untyped(),
Self::Bool(v) => v.to_untyped(),
Self::Int(v) => v.to_untyped(),
Self::Float(v) => v.to_untyped(),
Self::Numeric(v) => v.to_untyped(),
Self::Str(v) => v.to_untyped(),
Self::Code(v) => v.to_untyped(),
Self::Content(v) => v.to_untyped(),
Self::Array(v) => v.to_untyped(),
Self::Dict(v) => v.to_untyped(),
Self::Parenthesized(v) => v.to_untyped(),
Self::Unary(v) => v.to_untyped(),
Self::Binary(v) => v.to_untyped(),
Self::FieldAccess(v) => v.to_untyped(),
Self::FuncCall(v) => v.to_untyped(),
Self::Closure(v) => v.to_untyped(),
Self::Let(v) => v.to_untyped(),
Self::DestructAssign(v) => v.to_untyped(),
Self::Set(v) => v.to_untyped(),
Self::Show(v) => v.to_untyped(),
Self::Contextual(v) => v.to_untyped(),
Self::Conditional(v) => v.to_untyped(),
Self::While(v) => v.to_untyped(),
Self::For(v) => v.to_untyped(),
Self::Import(v) => v.to_untyped(),
Self::Include(v) => v.to_untyped(),
Self::Break(v) => v.to_untyped(),
Self::Continue(v) => v.to_untyped(),
Self::Return(v) => v.to_untyped(),
}
}
}
impl Expr<'_> {
    /// Returns whether this expression is of a kind listed as embeddable with a
    /// hash. NOTE(review): the meaning is inferred from the method name; the
    /// match below is the authoritative list.
    ///
    /// `Expr` also derives `Hash`. In method-call syntax this inherent method
    /// takes precedence over `Hash::hash`, so use `Hash::hash(&expr, state)` to
    /// reach the trait method.
    pub fn hash(self) -> bool {
        match self {
            Self::Ident(_)
            | Self::None(_)
            | Self::Auto(_)
            | Self::Bool(_)
            | Self::Int(_)
            | Self::Float(_)
            | Self::Numeric(_)
            | Self::Str(_)
            | Self::Code(_)
            | Self::Content(_)
            | Self::Array(_)
            | Self::Dict(_)
            | Self::Parenthesized(_)
            | Self::FieldAccess(_)
            | Self::FuncCall(_)
            | Self::Let(_)
            | Self::Set(_)
            | Self::Show(_)
            | Self::Contextual(_)
            | Self::Conditional(_)
            | Self::While(_)
            | Self::For(_)
            | Self::Import(_)
            | Self::Include(_)
            | Self::Break(_)
            | Self::Continue(_)
            | Self::Return(_) => true,
            _ => false,
        }
    }

    /// Returns whether this is a literal: `none`, `auto`, a boolean, an
    /// integer, a float, a numeric-with-unit or a string.
    pub fn is_literal(self) -> bool {
        match self {
            Self::None(_)
            | Self::Auto(_)
            | Self::Bool(_)
            | Self::Int(_)
            | Self::Float(_)
            | Self::Numeric(_)
            | Self::Str(_) => true,
            _ => false,
        }
    }
}
impl Default for Expr<'_> {
fn default() -> Self {
Expr::None(None::default())
}
}
node! {
Text
}
impl<'a> Text<'a> {
pub fn get(self) -> &'a EcoString {
self.0.text()
}
}
node! {
/// A space node (kind `SyntaxKind::Space`).
Space
}
node! {
/// A line break node (kind `SyntaxKind::Linebreak`).
Linebreak
}
node! {
/// A paragraph break node (kind `SyntaxKind::Parbreak`).
Parbreak
}
node! {
    /// A backslash escape sequence (kind `SyntaxKind::Escape`).
    Escape
}

impl Escape<'_> {
    /// Returns the character this escape sequence stands for.
    ///
    /// For `\u{...}`, the hex digits after the brace are decoded. If they are
    /// missing or do not form a valid code point, the result is `'\0'`
    /// (`char::default()`). Any other escape yields the character right after
    /// the backslash, or `'\0'` if there is none.
    pub fn get(self) -> char {
        let mut scanner = Scanner::new(self.0.text());
        scanner.expect('\\');
        if !scanner.eat_if("u{") {
            return scanner.eat().unwrap_or_default();
        }
        let digits = scanner.eat_while(char::is_ascii_hexdigit);
        match u32::from_str_radix(digits, 16) {
            Ok(code) => std::char::from_u32(code).unwrap_or_default(),
            Err(_) => char::default(),
        }
    }
}
node! {
    /// A shorthand sequence that stands for a single character
    /// (kind `SyntaxKind::Shorthand`).
    Shorthand
}

impl Shorthand<'_> {
    /// Recognized shorthand sequences and the characters they map to.
    pub const LIST: &'static [(&'static str, char)] = &[
        ("...", '…'),
        ("~", '\u{00A0}'),
        ("-", '\u{2212}'),
        ("--", '\u{2013}'),
        ("---", '\u{2014}'),
        ("-?", '\u{00AD}'),
    ];

    /// Returns the character for this shorthand's text, or `'\0'` if the text
    /// is not in [`Self::LIST`].
    pub fn get(self) -> char {
        let text = self.0.text();
        for &(sequence, symbol) in Self::LIST {
            if sequence == text {
                return symbol;
            }
        }
        char::default()
    }
}
node! {
    /// A smart quote (kind `SyntaxKind::SmartQuote`).
    SmartQuote
}

impl SmartQuote<'_> {
    /// Returns `true` if the quote's text is a double quote (`"`).
    pub fn double(self) -> bool {
        matches!(self.0.text().as_str(), "\"")
    }
}
node! {
    /// Strong content (kind `SyntaxKind::Strong`).
    Strong
}

impl<'a> Strong<'a> {
    /// Returns the `Markup` child found by `cast_first_match`, or the
    /// placeholder `Markup` if there is none.
    pub fn body(self) -> Markup<'a> {
        let body: Option<Markup<'a>> = self.0.cast_first_match();
        body.unwrap_or_default()
    }
}

node! {
    /// Emphasized content (kind `SyntaxKind::Emph`).
    Emph
}

impl<'a> Emph<'a> {
    /// Returns the `Markup` child found by `cast_first_match`, or the
    /// placeholder `Markup` if there is none.
    pub fn body(self) -> Markup<'a> {
        let body: Option<Markup<'a>> = self.0.cast_first_match();
        body.unwrap_or_default()
    }
}
node! {
    /// Raw text (kind `SyntaxKind::Raw`).
    Raw
}

impl<'a> Raw<'a> {
    /// Iterates over the children that cast to [`Text`].
    pub fn lines(self) -> impl DoubleEndedIterator<Item = Text<'a>> {
        self.0.children().filter_map(SyntaxNode::cast)
    }

    /// Returns the language tag.
    ///
    /// A tag is reported only when a `RawDelim` child exists and its `len()`
    /// is at least 3. Shorter delimiters never carry a language.
    pub fn lang(self) -> Option<RawLang<'a>> {
        let delim = self.0.cast_first_match::<RawDelim>()?;
        if delim.0.len() >= 3 {
            self.0.cast_first_match()
        } else {
            Option::None
        }
    }

    /// Returns `true` if the delimiter's `len()` is at least 3 and some
    /// `RawTrimmed` child contains a newline.
    pub fn block(self) -> bool {
        let long_delim = self
            .0
            .cast_first_match::<RawDelim>()
            .is_some_and(|delim| delim.0.len() >= 3);
        long_delim
            && self.0.children().any(|child| {
                child.kind() == SyntaxKind::RawTrimmed && child.text().chars().any(is_newline)
            })
    }
}
node! {
RawLang
}
impl<'a> RawLang<'a> {
pub fn get(self) -> &'a EcoString {
self.0.text()
}
}
node! {
RawDelim
}
node! {
Link
}
impl<'a> Link<'a> {
pub fn get(self) -> &'a EcoString {
self.0.text()
}
}
node! {
    /// A label (kind `SyntaxKind::Label`).
    Label
}

impl<'a> Label<'a> {
    /// Returns the label's name: the node text with all leading `<` and all
    /// trailing `>` characters trimmed.
    pub fn get(self) -> &'a str {
        let text: &'a str = self.0.text().as_str();
        text.trim_start_matches('<').trim_end_matches('>')
    }
}
node! {
    /// A reference (kind `SyntaxKind::Ref`).
    Ref
}

impl<'a> Ref<'a> {
    /// Returns the reference target: the first `RefMarker` child's text with
    /// all leading `@` characters removed, or `""` if there is no marker.
    pub fn target(self) -> &'a str {
        let marker = self
            .0
            .children()
            .find(|child| child.kind() == SyntaxKind::RefMarker);
        match marker {
            Some(node) => node.text().trim_start_matches('@'),
            Option::None => "",
        }
    }

    /// Returns the supplement content block found by `cast_last_match`, if any.
    pub fn supplement(self) -> Option<ContentBlock<'a>> {
        self.0.cast_last_match::<ContentBlock>()
    }
}
node! {
    /// A heading (kind `SyntaxKind::Heading`).
    Heading
}

impl<'a> Heading<'a> {
    /// Returns the heading's `Markup` body, or the placeholder if there is none.
    pub fn body(self) -> Markup<'a> {
        self.0.cast_first_match::<Markup>().unwrap_or_default()
    }

    /// Returns the heading depth: the `len()` of the first `HeadingMarker`
    /// child. Falls back to 1 if there is no marker or its length is zero.
    pub fn depth(self) -> NonZeroUsize {
        let marker = self
            .0
            .children()
            .find(|child| child.kind() == SyntaxKind::HeadingMarker);
        marker
            .and_then(|marker| NonZeroUsize::new(marker.len()))
            .unwrap_or(NonZeroUsize::MIN)
    }
}
node! {
    /// A bullet list item (kind `SyntaxKind::ListItem`).
    ListItem
}

impl<'a> ListItem<'a> {
    /// Returns the item's `Markup` body, or the placeholder if there is none.
    pub fn body(self) -> Markup<'a> {
        let body: Option<Markup<'a>> = self.0.cast_first_match();
        body.unwrap_or_default()
    }
}
node! {
    /// A numbered list item (kind `SyntaxKind::EnumItem`).
    EnumItem
}

impl<'a> EnumItem<'a> {
    /// Returns the explicit item number, if any.
    ///
    /// Tries each `EnumMarker` child in order, stripping trailing `.`
    /// characters and parsing the rest as `usize`. Returns the first
    /// successful parse.
    pub fn number(self) -> Option<usize> {
        self.0
            .children()
            .filter(|child| child.kind() == SyntaxKind::EnumMarker)
            .find_map(|marker| marker.text().trim_end_matches('.').parse().ok())
    }

    /// Returns the item's `Markup` body, or the placeholder if there is none.
    pub fn body(self) -> Markup<'a> {
        self.0.cast_first_match::<Markup>().unwrap_or_default()
    }
}
node! {
    /// A term list item (kind `SyntaxKind::TermItem`).
    TermItem
}

impl<'a> TermItem<'a> {
    /// Returns the first `Markup` child (the term), or the placeholder.
    pub fn term(self) -> Markup<'a> {
        self.0.cast_first_match::<Markup>().unwrap_or_default()
    }

    /// Returns the last `Markup` child (the description), or the placeholder.
    pub fn description(self) -> Markup<'a> {
        self.0.cast_last_match::<Markup>().unwrap_or_default()
    }
}
node! {
    /// A math equation (kind `SyntaxKind::Equation`).
    Equation
}

impl<'a> Equation<'a> {
    /// Returns the equation's `Math` body, or the placeholder if there is none.
    pub fn body(self) -> Math<'a> {
        let body: Option<Math<'a>> = self.0.cast_first_match();
        body.unwrap_or_default()
    }

    /// Returns `true` if both the second child and the second-to-last child
    /// are `Space` nodes.
    pub fn block(self) -> bool {
        let is_space = |child: Option<&SyntaxNode>| {
            child.is_some_and(|node| node.kind() == SyntaxKind::Space)
        };
        let after_open = self.0.children().nth(1);
        let before_close = self.0.children().nth_back(1);
        is_space(after_open) && is_space(before_close)
    }
}
node! {
    /// The contents of an equation (kind `SyntaxKind::Math`).
    Math
}

impl<'a> Math<'a> {
    /// Iterates over the expressions, including `Space` nodes, via
    /// `Expr::cast_with_space`.
    pub fn exprs(self) -> impl DoubleEndedIterator<Item = Expr<'a>> {
        let children = self.0.children();
        children.filter_map(Expr::cast_with_space)
    }
}
node! {
MathIdent
}
impl<'a> MathIdent<'a> {
pub fn get(self) -> &'a EcoString {
self.0.text()
}
pub fn as_str(self) -> &'a str {
self.get()
}
}
impl Deref for MathIdent<'_> {
type Target = str;
fn deref(&self) -> &Self::Target {
self.as_str()
}
}
node! {
    /// A shorthand sequence in math (kind `SyntaxKind::MathShorthand`).
    MathShorthand
}

impl MathShorthand<'_> {
    /// Recognized math shorthand sequences and the characters they map to.
    pub const LIST: &'static [(&'static str, char)] = &[
        ("...", '…'),
        ("-", '−'),
        ("'", '′'),
        ("*", '∗'),
        ("~", '∼'),
        ("!=", '≠'),
        (":=", '≔'),
        ("::=", '⩴'),
        ("=:", '≕'),
        ("<<", '≪'),
        ("<<<", '⋘'),
        (">>", '≫'),
        (">>>", '⋙'),
        ("<=", '≤'),
        (">=", '≥'),
        ("->", '→'),
        ("-->", '⟶'),
        ("|->", '↦'),
        (">->", '↣'),
        ("->>", '↠'),
        ("<-", '←'),
        ("<--", '⟵'),
        ("<-<", '↢'),
        ("<<-", '↞'),
        ("<->", '↔'),
        ("<-->", '⟷'),
        ("~>", '⇝'),
        ("~~>", '⟿'),
        ("<~", '⇜'),
        ("<~~", '⬳'),
        ("=>", '⇒'),
        ("|=>", '⤇'),
        ("==>", '⟹'),
        ("<==", '⟸'),
        ("<=>", '⇔'),
        ("<==>", '⟺'),
        ("[|", '⟦'),
        ("|]", '⟧'),
        ("||", '‖'),
    ];

    /// Returns the character for this shorthand's text, or `'\0'` if the text
    /// is not in [`Self::LIST`].
    pub fn get(self) -> char {
        let text = self.0.text();
        Self::LIST
            .iter()
            .find_map(|&(sequence, symbol)| (sequence == text).then_some(symbol))
            .unwrap_or_default()
    }
}
node! {
/// An alignment point in math (kind `SyntaxKind::MathAlignPoint`).
MathAlignPoint
}
node! {
    /// Math content between delimiters (kind `SyntaxKind::MathDelimited`).
    MathDelimited
}

impl<'a> MathDelimited<'a> {
    /// Returns the first child castable to `Expr` (the opening delimiter), or
    /// the default expression.
    pub fn open(self) -> Expr<'a> {
        self.0.cast_first_match::<Expr>().unwrap_or_default()
    }

    /// Returns the first `Math` child (the body), or the placeholder.
    pub fn body(self) -> Math<'a> {
        self.0.cast_first_match::<Math>().unwrap_or_default()
    }

    /// Returns the last child castable to `Expr` (the closing delimiter), or
    /// the default expression.
    pub fn close(self) -> Expr<'a> {
        self.0.cast_last_match::<Expr>().unwrap_or_default()
    }
}
node! {
    /// A base with optional attachments (kind `SyntaxKind::MathAttach`).
    MathAttach
}

impl<'a> MathAttach<'a> {
    /// Returns the first child castable to `Expr` (the base), or the default.
    pub fn base(self) -> Expr<'a> {
        let base: Option<Expr<'a>> = self.0.cast_first_match();
        base.unwrap_or_default()
    }

    /// Returns the first expression at or after the first `Underscore` child.
    pub fn bottom(self) -> Option<Expr<'a>> {
        self.0
            .children()
            .skip_while(|child| child.kind() != SyntaxKind::Underscore)
            .find_map(SyntaxNode::cast)
    }

    /// Returns the first expression at or after the first `Hat` child.
    pub fn top(self) -> Option<Expr<'a>> {
        self.0
            .children()
            .skip_while(|child| child.kind() != SyntaxKind::Hat)
            .find_map(SyntaxNode::cast)
    }

    /// Returns the child directly after the first `Expr`-castable child, if
    /// that child is a `MathPrimes` node.
    pub fn primes(self) -> Option<MathPrimes<'a>> {
        let mut children = self.0.children();
        children.find(|child| child.cast::<Expr<'_>>().is_some())?;
        children.next()?.cast()
    }
}
node! {
    /// A run of primes in math (kind `SyntaxKind::MathPrimes`).
    MathPrimes
}

impl MathPrimes<'_> {
    /// Returns the number of `Prime` children.
    pub fn count(self) -> usize {
        self.0
            .children()
            .filter(|child| child.kind() == SyntaxKind::Prime)
            .count()
    }
}
node! {
    /// A fraction in math (kind `SyntaxKind::MathFrac`).
    MathFrac
}

impl<'a> MathFrac<'a> {
    /// Returns the first `Expr`-castable child (the numerator), or the default.
    pub fn num(self) -> Expr<'a> {
        let num: Option<Expr<'a>> = self.0.cast_first_match();
        num.unwrap_or_default()
    }

    /// Returns the last `Expr`-castable child (the denominator), or the default.
    pub fn denom(self) -> Expr<'a> {
        let denom: Option<Expr<'a>> = self.0.cast_last_match();
        denom.unwrap_or_default()
    }
}
node! {
    /// A root in math (kind `SyntaxKind::MathRoot`).
    MathRoot
}

impl<'a> MathRoot<'a> {
    /// Returns the root index implied by the first child's text: 4 for `∜`,
    /// 3 for `∛`, and `None` for `√`, anything else, or no children.
    pub fn index(self) -> Option<usize> {
        let first = self.0.children().next()?;
        match first.text().as_str() {
            "∜" => Some(4),
            "∛" => Some(3),
            _ => Option::None,
        }
    }

    /// Returns the first `Expr`-castable child (the radicand), or the default.
    pub fn radicand(self) -> Expr<'a> {
        self.0.cast_first_match::<Expr>().unwrap_or_default()
    }
}
node! {
Ident
}
impl<'a> Ident<'a> {
pub fn get(self) -> &'a EcoString {
self.0.text()
}
pub fn as_str(self) -> &'a str {
self.get()
}
}
impl Deref for Ident<'_> {
type Target = str;
fn deref(&self) -> &Self::Target {
self.as_str()
}
}
node! {
/// The `none` literal (kind `SyntaxKind::None`). Inside this module the type
/// shadows the prelude's `None`, which is why the file writes `Option::None`.
None
}
node! {
/// The `auto` literal (kind `SyntaxKind::Auto`).
Auto
}
node! {
    /// A boolean literal (kind `SyntaxKind::Bool`).
    Bool
}

impl Bool<'_> {
    /// Returns `true` if the literal's text is `true`, otherwise `false`.
    pub fn get(self) -> bool {
        matches!(self.0.text().as_str(), "true")
    }
}
node! {
    /// An integer literal (kind `SyntaxKind::Int`).
    Int
}

impl Int<'_> {
    /// Returns the integer's value.
    ///
    /// The prefixes `0x`, `0o` and `0b` select base 16, 8 and 2; otherwise the
    /// text is parsed in base 10. Unparsable or out-of-range text yields `0`.
    pub fn get(self) -> i64 {
        let text = self.0.text().as_str();
        let (digits, radix) = [("0x", 16), ("0o", 8), ("0b", 2)]
            .into_iter()
            .find_map(|(prefix, radix)| text.strip_prefix(prefix).map(|rest| (rest, radix)))
            .unwrap_or((text, 10));
        i64::from_str_radix(digits, radix).unwrap_or_default()
    }
}
node! {
    /// A floating-point literal (kind `SyntaxKind::Float`).
    Float
}

impl Float<'_> {
    /// Parses the literal's text as `f64`; returns `0.0` if parsing fails.
    pub fn get(self) -> f64 {
        let text = self.0.text().as_str();
        text.parse::<f64>().unwrap_or_default()
    }
}
node! {
    /// A numeric literal with a unit suffix (kind `SyntaxKind::Numeric`).
    Numeric
}

impl Numeric<'_> {
    /// Splits the literal into its value and [`Unit`].
    ///
    /// The suffix is the longest trailing run of ASCII lowercase letters and
    /// `%`. The remaining prefix is parsed as `f64` (`0.0` on failure). An
    /// unrecognized suffix falls back to [`Unit::Percent`].
    pub fn get(self) -> (f64, Unit) {
        let text = self.0.text().as_str();
        let suffix_len = text
            .bytes()
            .rev()
            .take_while(|&byte| byte.is_ascii_lowercase() || byte == b'%')
            .count();
        let (number, suffix) = text.split_at(text.len() - suffix_len);
        let unit = match suffix {
            "pt" => Unit::Pt,
            "mm" => Unit::Mm,
            "cm" => Unit::Cm,
            "in" => Unit::In,
            "deg" => Unit::Deg,
            "rad" => Unit::Rad,
            "em" => Unit::Em,
            "fr" => Unit::Fr,
            // `%` and any unrecognized suffix both map to percent.
            _ => Unit::Percent,
        };
        (number.parse().unwrap_or_default(), unit)
    }
}

/// The unit attached to a [`Numeric`] literal.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum Unit {
    /// `pt`
    Pt,
    /// `mm`
    Mm,
    /// `cm`
    Cm,
    /// `in`
    In,
    /// `rad`
    Rad,
    /// `deg`
    Deg,
    /// `em`
    Em,
    /// `fr`
    Fr,
    /// `%` (also the fallback for unrecognized suffixes)
    Percent,
}
node! {
    /// A string literal (kind `SyntaxKind::Str`).
    Str
}

impl Str<'_> {
    /// Returns the string's value: the text between the surrounding quotes,
    /// with escape sequences resolved.
    ///
    /// Recognized escapes are `\\`, `\"`, `\n`, `\r`, `\t` and `\u{...}`. For
    /// an unrecognized escape, or a `\u{...}` whose digits do not form a valid
    /// code point, the consumed escape text is copied verbatim.
    ///
    /// If the text is too short to hold both quotes (fewer than two bytes, as
    /// a placeholder node may be — NOTE(review): placeholder text assumed
    /// empty), the result is an empty string instead of a panic.
    pub fn get(self) -> EcoString {
        let text = self.0.text().as_str();
        // Drop the first and last byte (the quotes). A plain `&text[1..len-1]`
        // underflows or slices an inverted range on texts shorter than two
        // bytes. `str::get` returns `None` in those cases instead.
        let unquoted = text
            .get(1..text.len().saturating_sub(1))
            .unwrap_or_default();
        if !unquoted.contains('\\') {
            return unquoted.into();
        }

        let mut out = EcoString::with_capacity(unquoted.len());
        let mut s = Scanner::new(unquoted);

        while let Some(c) = s.eat() {
            if c != '\\' {
                out.push(c);
                continue;
            }

            // Position of the backslash, so an invalid escape can be copied as-is.
            let start = s.locate(-1);
            match s.eat() {
                Some('\\') => out.push('\\'),
                Some('"') => out.push('"'),
                Some('n') => out.push('\n'),
                Some('r') => out.push('\r'),
                Some('t') => out.push('\t'),
                Some('u') if s.eat_if('{') => {
                    let sequence = s.eat_while(char::is_ascii_hexdigit);
                    s.eat_if('}');

                    match u32::from_str_radix(sequence, 16)
                        .ok()
                        .and_then(std::char::from_u32)
                    {
                        Some(c) => out.push(c),
                        Option::None => out.push_str(s.from(start)),
                    }
                }
                _ => out.push_str(s.from(start)),
            }
        }

        out
    }
}
node! {
    /// A code block (kind `SyntaxKind::CodeBlock`).
    CodeBlock
}

impl<'a> CodeBlock<'a> {
    /// Returns the block's `Code` body, or the placeholder if there is none.
    pub fn body(self) -> Code<'a> {
        let body: Option<Code<'a>> = self.0.cast_first_match();
        body.unwrap_or_default()
    }
}

node! {
    /// A sequence of code expressions (kind `SyntaxKind::Code`).
    Code
}

impl<'a> Code<'a> {
    /// Iterates over the children that cast to an [`Expr`].
    pub fn exprs(self) -> impl DoubleEndedIterator<Item = Expr<'a>> {
        let children = self.0.children();
        children.filter_map(SyntaxNode::cast)
    }
}
node! {
    /// A content block (kind `SyntaxKind::ContentBlock`).
    ContentBlock
}

impl<'a> ContentBlock<'a> {
    /// Returns the block's `Markup` body, or the placeholder if there is none.
    pub fn body(self) -> Markup<'a> {
        self.0.cast_first_match::<Markup>().unwrap_or_default()
    }
}
node! {
    /// A parenthesized expression or pattern (kind `SyntaxKind::Parenthesized`).
    Parenthesized
}

impl<'a> Parenthesized<'a> {
    /// Returns the inner child viewed as an expression, or the default.
    pub fn expr(self) -> Expr<'a> {
        let inner: Option<Expr<'a>> = self.0.cast_first_match();
        inner.unwrap_or_default()
    }

    /// Returns the inner child viewed as a pattern, or the default.
    pub fn pattern(self) -> Pattern<'a> {
        let inner: Option<Pattern<'a>> = self.0.cast_first_match();
        inner.unwrap_or_default()
    }
}
node! {
    /// An array literal (kind `SyntaxKind::Array`).
    Array
}

impl<'a> Array<'a> {
    /// Iterates over the children that cast to an [`ArrayItem`].
    pub fn items(self) -> impl DoubleEndedIterator<Item = ArrayItem<'a>> {
        let children = self.0.children();
        children.filter_map(SyntaxNode::cast)
    }
}

/// An item of an array literal.
#[derive(Debug, Copy, Clone, Hash)]
pub enum ArrayItem<'a> {
    /// A plain positional expression.
    Pos(Expr<'a>),
    /// A spread (`Spread` node).
    Spread(Spread<'a>),
}

impl<'a> AstNode<'a> for ArrayItem<'a> {
    /// `Spread` nodes become [`ArrayItem::Spread`]. Any other node is tried as
    /// an [`Expr`].
    fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
        if node.kind() == SyntaxKind::Spread {
            node.cast().map(Self::Spread)
        } else {
            node.cast().map(Self::Pos)
        }
    }

    fn to_untyped(self) -> &'a SyntaxNode {
        match self {
            Self::Pos(expr) => expr.to_untyped(),
            Self::Spread(spread) => spread.to_untyped(),
        }
    }
}
node! {
    /// A dictionary literal (kind `SyntaxKind::Dict`).
    Dict
}

impl<'a> Dict<'a> {
    /// Iterates over the children that cast to a [`DictItem`].
    pub fn items(self) -> impl DoubleEndedIterator<Item = DictItem<'a>> {
        let children = self.0.children();
        children.filter_map(SyntaxNode::cast)
    }
}

/// An item of a dictionary literal.
#[derive(Debug, Copy, Clone, Hash)]
pub enum DictItem<'a> {
    /// A `Named` pair.
    Named(Named<'a>),
    /// A `Keyed` pair.
    Keyed(Keyed<'a>),
    /// A `Spread`.
    Spread(Spread<'a>),
}

impl<'a> AstNode<'a> for DictItem<'a> {
    /// Accepts only `Named`, `Keyed` and `Spread` nodes.
    fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
        let kind = node.kind();
        if kind == SyntaxKind::Named {
            node.cast().map(Self::Named)
        } else if kind == SyntaxKind::Keyed {
            node.cast().map(Self::Keyed)
        } else if kind == SyntaxKind::Spread {
            node.cast().map(Self::Spread)
        } else {
            Option::None
        }
    }

    fn to_untyped(self) -> &'a SyntaxNode {
        match self {
            Self::Named(named) => named.to_untyped(),
            Self::Keyed(keyed) => keyed.to_untyped(),
            Self::Spread(spread) => spread.to_untyped(),
        }
    }
}
node! {
    /// A `name: value` pair (kind `SyntaxKind::Named`).
    Named
}

impl<'a> Named<'a> {
    /// Returns the first `Ident` child (the name), or the placeholder.
    pub fn name(self) -> Ident<'a> {
        self.0.cast_first_match::<Ident>().unwrap_or_default()
    }

    /// Returns the last child viewed as an expression, or the default.
    pub fn expr(self) -> Expr<'a> {
        self.0.cast_last_match::<Expr>().unwrap_or_default()
    }

    /// Returns the last child viewed as a pattern, or the default.
    pub fn pattern(self) -> Pattern<'a> {
        self.0.cast_last_match::<Pattern>().unwrap_or_default()
    }
}
node! {
    /// A `key: value` pair with an expression key (kind `SyntaxKind::Keyed`).
    Keyed
}

impl<'a> Keyed<'a> {
    /// Returns the first `Expr`-castable child (the key), or the default.
    pub fn key(self) -> Expr<'a> {
        let key: Option<Expr<'a>> = self.0.cast_first_match();
        key.unwrap_or_default()
    }

    /// Returns the last `Expr`-castable child (the value), or the default.
    pub fn expr(self) -> Expr<'a> {
        let value: Option<Expr<'a>> = self.0.cast_last_match();
        value.unwrap_or_default()
    }
}
node! {
    /// A spread (kind `SyntaxKind::Spread`).
    Spread
}

impl<'a> Spread<'a> {
    /// Returns the first `Expr`-castable child, or the default.
    pub fn expr(self) -> Expr<'a> {
        self.0.cast_first_match::<Expr>().unwrap_or_default()
    }

    /// Returns the first `Ident` child, if any.
    pub fn sink_ident(self) -> Option<Ident<'a>> {
        self.0.cast_first_match::<Ident>()
    }

    /// Returns the first `Expr`-castable child, if any.
    pub fn sink_expr(self) -> Option<Expr<'a>> {
        self.0.cast_first_match::<Expr>()
    }
}
node! {
    /// A unary operation (kind `SyntaxKind::Unary`).
    Unary
}

impl<'a> Unary<'a> {
    /// Returns the operator of the first child whose kind maps to a [`UnOp`],
    /// or [`UnOp::Pos`] if none does.
    pub fn op(self) -> UnOp {
        self.0
            .children()
            .map(|child| child.kind())
            .find_map(UnOp::from_kind)
            .unwrap_or(UnOp::Pos)
    }

    /// Returns the last `Expr`-castable child (the operand), or the default.
    pub fn expr(self) -> Expr<'a> {
        let operand: Option<Expr<'a>> = self.0.cast_last_match();
        operand.unwrap_or_default()
    }
}
/// A unary operator.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum UnOp {
    /// `+`
    Pos,
    /// `-`
    Neg,
    /// `not`
    Not,
}

impl UnOp {
    /// Maps a token kind to its unary operator, if it is one.
    pub fn from_kind(token: SyntaxKind) -> Option<Self> {
        match token {
            SyntaxKind::Plus => Some(Self::Pos),
            SyntaxKind::Minus => Some(Self::Neg),
            SyntaxKind::Not => Some(Self::Not),
            _ => Option::None,
        }
    }

    /// Returns the precedence level: 7 for `+`/`-` and 4 for `not`.
    pub fn precedence(self) -> usize {
        if let Self::Not = self {
            4
        } else {
            7
        }
    }

    /// Returns the operator's source spelling.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Not => "not",
            Self::Neg => "-",
            Self::Pos => "+",
        }
    }
}
node! {
    /// A binary operation (kind `SyntaxKind::Binary`).
    Binary
}

impl<'a> Binary<'a> {
    /// Returns the operator.
    ///
    /// Scans the children for the first operator token. A `Not` token is
    /// remembered, so a later `In` is reported as [`BinOp::NotIn`]. Falls back
    /// to [`BinOp::Add`] if no operator is found.
    pub fn op(self) -> BinOp {
        let mut saw_not = false;
        for child in self.0.children() {
            let kind = child.kind();
            if kind == SyntaxKind::Not {
                saw_not = true;
                continue;
            }
            if saw_not && kind == SyntaxKind::In {
                return BinOp::NotIn;
            }
            if let Some(op) = BinOp::from_kind(kind) {
                return op;
            }
        }
        BinOp::Add
    }

    /// Returns the first `Expr`-castable child (the left operand), or the default.
    pub fn lhs(self) -> Expr<'a> {
        self.0.cast_first_match::<Expr>().unwrap_or_default()
    }

    /// Returns the last `Expr`-castable child (the right operand), or the default.
    pub fn rhs(self) -> Expr<'a> {
        self.0.cast_last_match::<Expr>().unwrap_or_default()
    }
}
/// A binary operator.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum BinOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `and`
    And,
    /// `or`
    Or,
    /// `==`
    Eq,
    /// `!=`
    Neq,
    /// `<`
    Lt,
    /// `<=`
    Leq,
    /// `>`
    Gt,
    /// `>=`
    Geq,
    /// `=`
    Assign,
    /// `in`
    In,
    /// `not in`
    NotIn,
    /// `+=`
    AddAssign,
    /// `-=`
    SubAssign,
    /// `*=`
    MulAssign,
    /// `/=`
    DivAssign,
}

impl BinOp {
    /// Maps a single token kind to its binary operator, if it is one.
    ///
    /// [`BinOp::NotIn`] spans two tokens and is never produced here.
    pub fn from_kind(token: SyntaxKind) -> Option<Self> {
        let op = match token {
            SyntaxKind::Plus => Self::Add,
            SyntaxKind::Minus => Self::Sub,
            SyntaxKind::Star => Self::Mul,
            SyntaxKind::Slash => Self::Div,
            SyntaxKind::And => Self::And,
            SyntaxKind::Or => Self::Or,
            SyntaxKind::EqEq => Self::Eq,
            SyntaxKind::ExclEq => Self::Neq,
            SyntaxKind::Lt => Self::Lt,
            SyntaxKind::LtEq => Self::Leq,
            SyntaxKind::Gt => Self::Gt,
            SyntaxKind::GtEq => Self::Geq,
            SyntaxKind::Eq => Self::Assign,
            SyntaxKind::In => Self::In,
            SyntaxKind::PlusEq => Self::AddAssign,
            SyntaxKind::HyphEq => Self::SubAssign,
            SyntaxKind::StarEq => Self::MulAssign,
            SyntaxKind::SlashEq => Self::DivAssign,
            _ => return Option::None,
        };
        Some(op)
    }

    /// Returns the precedence level, from 6 (`*`, `/`) down to 1 (assignments).
    /// NOTE(review): presumably a higher value binds tighter; confirm against
    /// the parser.
    pub fn precedence(self) -> usize {
        match self {
            Self::Mul | Self::Div => 6,
            Self::Add | Self::Sub => 5,
            Self::Eq
            | Self::Neq
            | Self::Lt
            | Self::Leq
            | Self::Gt
            | Self::Geq
            | Self::In
            | Self::NotIn => 4,
            Self::And => 3,
            Self::Or => 2,
            Self::Assign
            | Self::AddAssign
            | Self::SubAssign
            | Self::MulAssign
            | Self::DivAssign => 1,
        }
    }

    /// Returns the associativity: right for the assignment operators and left
    /// for everything else.
    pub fn assoc(self) -> Assoc {
        match self {
            Self::Assign
            | Self::AddAssign
            | Self::SubAssign
            | Self::MulAssign
            | Self::DivAssign => Assoc::Right,
            Self::Add
            | Self::Sub
            | Self::Mul
            | Self::Div
            | Self::And
            | Self::Or
            | Self::Eq
            | Self::Neq
            | Self::Lt
            | Self::Leq
            | Self::Gt
            | Self::Geq
            | Self::In
            | Self::NotIn => Assoc::Left,
        }
    }

    /// Returns the operator's source spelling.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Add => "+",
            Self::Sub => "-",
            Self::Mul => "*",
            Self::Div => "/",
            Self::And => "and",
            Self::Or => "or",
            Self::Eq => "==",
            Self::Neq => "!=",
            Self::Lt => "<",
            Self::Leq => "<=",
            Self::Gt => ">",
            Self::Geq => ">=",
            Self::In => "in",
            Self::NotIn => "not in",
            Self::Assign => "=",
            Self::AddAssign => "+=",
            Self::SubAssign => "-=",
            Self::MulAssign => "*=",
            Self::DivAssign => "/=",
        }
    }
}

/// The associativity of a binary operator.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum Assoc {
    /// Left-associative.
    Left,
    /// Right-associative.
    Right,
}
node! {
    /// A field access (kind `SyntaxKind::FieldAccess`).
    FieldAccess
}

impl<'a> FieldAccess<'a> {
    /// Returns the first `Expr`-castable child (the accessed value), or the default.
    pub fn target(self) -> Expr<'a> {
        let target: Option<Expr<'a>> = self.0.cast_first_match();
        target.unwrap_or_default()
    }

    /// Returns the last `Ident` child (the field name), or the placeholder.
    pub fn field(self) -> Ident<'a> {
        let field: Option<Ident<'a>> = self.0.cast_last_match();
        field.unwrap_or_default()
    }
}
node! {
    /// A function call (kind `SyntaxKind::FuncCall`).
    FuncCall
}

impl<'a> FuncCall<'a> {
    /// Returns the first `Expr`-castable child (the callee), or the default.
    pub fn callee(self) -> Expr<'a> {
        self.0.cast_first_match::<Expr>().unwrap_or_default()
    }

    /// Returns the last `Args` child, or the placeholder.
    pub fn args(self) -> Args<'a> {
        self.0.cast_last_match::<Args>().unwrap_or_default()
    }
}
node! {
    /// An argument list (kind `SyntaxKind::Args`).
    Args
}

impl<'a> Args<'a> {
    /// Iterates over the children that cast to an [`Arg`].
    pub fn items(self) -> impl DoubleEndedIterator<Item = Arg<'a>> {
        let children = self.0.children();
        children.filter_map(SyntaxNode::cast)
    }

    /// Returns `true` if the last non-trivia child before the final child is a
    /// comma. NOTE(review): the skipped final child is presumably the closing
    /// delimiter.
    pub fn trailing_comma(self) -> bool {
        let before_last = self
            .0
            .children()
            .rev()
            .skip(1)
            .find(|child| !child.kind().is_trivia());
        matches!(before_last, Some(child) if child.kind() == SyntaxKind::Comma)
    }
}
/// An argument in an argument list.
#[derive(Debug, Copy, Clone, Hash)]
pub enum Arg<'a> {
    /// A positional expression argument.
    Pos(Expr<'a>),
    /// A `Named` argument.
    Named(Named<'a>),
    /// A `Spread` argument.
    Spread(Spread<'a>),
}

impl<'a> AstNode<'a> for Arg<'a> {
    /// `Named` and `Spread` nodes map to their variants. Any other node is
    /// tried as a positional [`Expr`].
    fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
        let kind = node.kind();
        if kind == SyntaxKind::Named {
            node.cast().map(Self::Named)
        } else if kind == SyntaxKind::Spread {
            node.cast().map(Self::Spread)
        } else {
            node.cast().map(Self::Pos)
        }
    }

    fn to_untyped(self) -> &'a SyntaxNode {
        match self {
            Self::Pos(expr) => expr.to_untyped(),
            Self::Named(named) => named.to_untyped(),
            Self::Spread(spread) => spread.to_untyped(),
        }
    }
}
node! {
    /// A closure (kind `SyntaxKind::Closure`).
    Closure
}

impl<'a> Closure<'a> {
    /// Returns the name if the very first child is an `Ident`.
    pub fn name(self) -> Option<Ident<'a>> {
        let first = self.0.children().next()?;
        first.cast()
    }

    /// Returns the first `Params` child, or the placeholder.
    pub fn params(self) -> Params<'a> {
        let params: Option<Params<'a>> = self.0.cast_first_match();
        params.unwrap_or_default()
    }

    /// Returns the last `Expr`-castable child (the body), or the default.
    pub fn body(self) -> Expr<'a> {
        let body: Option<Expr<'a>> = self.0.cast_last_match();
        body.unwrap_or_default()
    }
}
node! {
    /// A parameter list (kind `SyntaxKind::Params`).
    Params
}

impl<'a> Params<'a> {
    /// Iterates over the children that cast to a [`Param`].
    pub fn children(self) -> impl DoubleEndedIterator<Item = Param<'a>> {
        let children = self.0.children();
        children.filter_map(SyntaxNode::cast)
    }
}

/// A parameter in a parameter list.
#[derive(Debug, Copy, Clone, Hash)]
pub enum Param<'a> {
    /// A positional parameter bound by a pattern.
    Pos(Pattern<'a>),
    /// A `Named` parameter.
    Named(Named<'a>),
    /// A `Spread` parameter.
    Spread(Spread<'a>),
}

impl<'a> AstNode<'a> for Param<'a> {
    /// `Named` and `Spread` nodes map to their variants. Any other node is
    /// tried as a [`Pattern`].
    fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
        let kind = node.kind();
        if kind == SyntaxKind::Named {
            node.cast().map(Self::Named)
        } else if kind == SyntaxKind::Spread {
            node.cast().map(Self::Spread)
        } else {
            node.cast().map(Self::Pos)
        }
    }

    fn to_untyped(self) -> &'a SyntaxNode {
        match self {
            Self::Pos(pattern) => pattern.to_untyped(),
            Self::Named(named) => named.to_untyped(),
            Self::Spread(spread) => spread.to_untyped(),
        }
    }
}
/// A binding pattern, as used by `let` bindings, `for` loops, parameters and
/// destructuring.
#[derive(Debug, Copy, Clone, Hash)]
pub enum Pattern<'a> {
    /// Any expression; binds a name only when it is an identifier.
    Normal(Expr<'a>),
    /// An `Underscore` placeholder.
    Placeholder(Underscore<'a>),
    /// A parenthesized pattern.
    Parenthesized(Parenthesized<'a>),
    /// A destructuring pattern.
    Destructuring(Destructuring<'a>),
}

impl<'a> AstNode<'a> for Pattern<'a> {
    /// `Underscore`, `Parenthesized` and `Destructuring` nodes map to their
    /// variants. Any other node is tried as an [`Expr`].
    fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
        let kind = node.kind();
        if kind == SyntaxKind::Underscore {
            node.cast().map(Self::Placeholder)
        } else if kind == SyntaxKind::Parenthesized {
            node.cast().map(Self::Parenthesized)
        } else if kind == SyntaxKind::Destructuring {
            node.cast().map(Self::Destructuring)
        } else {
            node.cast().map(Self::Normal)
        }
    }

    fn to_untyped(self) -> &'a SyntaxNode {
        match self {
            Self::Normal(expr) => expr.to_untyped(),
            Self::Placeholder(underscore) => underscore.to_untyped(),
            Self::Parenthesized(inner) => inner.to_untyped(),
            Self::Destructuring(destructuring) => destructuring.to_untyped(),
        }
    }
}

impl<'a> Pattern<'a> {
    /// Returns the identifiers this pattern binds, in source order.
    pub fn bindings(self) -> Vec<Ident<'a>> {
        match self {
            Self::Normal(Expr::Ident(ident)) => vec![ident],
            Self::Parenthesized(inner) => inner.pattern().bindings(),
            Self::Destructuring(destructuring) => destructuring.bindings(),
            Self::Normal(_) | Self::Placeholder(_) => Vec::new(),
        }
    }
}

impl Default for Pattern<'_> {
    /// Defaults to a normal pattern around the default expression.
    fn default() -> Self {
        Pattern::Normal(Default::default())
    }
}
node! {
    /// An underscore placeholder (kind `SyntaxKind::Underscore`).
    Underscore
}

node! {
    /// A destructuring pattern (kind `SyntaxKind::Destructuring`).
    Destructuring
}

impl<'a> Destructuring<'a> {
    /// Iterates over the children that cast to a [`DestructuringItem`].
    pub fn items(self) -> impl DoubleEndedIterator<Item = DestructuringItem<'a>> {
        let children = self.0.children();
        children.filter_map(SyntaxNode::cast)
    }

    /// Returns every identifier bound by the items, in source order. A spread
    /// contributes its sink identifier, if it has one.
    pub fn bindings(self) -> Vec<Ident<'a>> {
        let mut idents = Vec::new();
        for item in self.items() {
            match item {
                DestructuringItem::Pattern(pattern) => idents.extend(pattern.bindings()),
                DestructuringItem::Named(named) => idents.extend(named.pattern().bindings()),
                DestructuringItem::Spread(spread) => idents.extend(spread.sink_ident()),
            }
        }
        idents
    }
}
/// An item inside a destructuring pattern.
#[derive(Debug, Copy, Clone, Hash)]
pub enum DestructuringItem<'a> {
    /// A nested pattern.
    Pattern(Pattern<'a>),
    /// A `Named` item.
    Named(Named<'a>),
    /// A `Spread` item.
    Spread(Spread<'a>),
}

impl<'a> AstNode<'a> for DestructuringItem<'a> {
    /// `Named` and `Spread` nodes map to their variants. Any other node is
    /// tried as a [`Pattern`].
    fn from_untyped(node: &'a SyntaxNode) -> Option<Self> {
        let kind = node.kind();
        if kind == SyntaxKind::Named {
            node.cast().map(Self::Named)
        } else if kind == SyntaxKind::Spread {
            node.cast().map(Self::Spread)
        } else {
            node.cast().map(Self::Pattern)
        }
    }

    fn to_untyped(self) -> &'a SyntaxNode {
        match self {
            Self::Pattern(pattern) => pattern.to_untyped(),
            Self::Named(named) => named.to_untyped(),
            Self::Spread(spread) => spread.to_untyped(),
        }
    }
}
node! {
    /// A `let` binding (kind `SyntaxKind::LetBinding`).
    LetBinding
}

/// What a `let` binding binds.
#[derive(Debug)]
pub enum LetBindingKind<'a> {
    /// A binding through a pattern.
    Normal(Pattern<'a>),
    /// A closure binding, carrying the closure's name.
    Closure(Ident<'a>),
}

impl<'a> LetBindingKind<'a> {
    /// Returns the identifiers introduced by this binding.
    pub fn bindings(self) -> Vec<Ident<'a>> {
        match self {
            Self::Closure(name) => vec![name],
            Self::Normal(pattern) => pattern.bindings(),
        }
    }
}

impl<'a> LetBinding<'a> {
    /// Classifies the binding from its first `Pattern`-castable child.
    ///
    /// A closure expression yields [`LetBindingKind::Closure`] with the
    /// closure's name (or the placeholder `Ident`). Anything else, or no
    /// pattern at all, yields [`LetBindingKind::Normal`].
    pub fn kind(self) -> LetBindingKind<'a> {
        let pattern = self.0.cast_first_match::<Pattern>();
        if let Some(Pattern::Normal(Expr::Closure(closure))) = pattern {
            return LetBindingKind::Closure(closure.name().unwrap_or_default());
        }
        LetBindingKind::Normal(pattern.unwrap_or_default())
    }

    /// Returns the initializer expression, if any.
    ///
    /// Normal and parenthesized patterns are themselves expressions, so the
    /// initializer is the second `Expr`-castable child. For other patterns and
    /// for closure bindings it is the first (for closures, the closure itself).
    pub fn init(self) -> Option<Expr<'a>> {
        match self.kind() {
            LetBindingKind::Normal(Pattern::Normal(_) | Pattern::Parenthesized(_)) => {
                self.0.children().filter_map(SyntaxNode::cast).nth(1)
            }
            LetBindingKind::Normal(_) | LetBindingKind::Closure(_) => {
                self.0.cast_first_match()
            }
        }
    }
}
node! {
    /// An assignment to a destructuring pattern
    /// (kind `SyntaxKind::DestructAssignment`).
    DestructAssignment
}

impl<'a> DestructAssignment<'a> {
    /// Returns the first `Pattern`-castable child, or the default pattern.
    pub fn pattern(self) -> Pattern<'a> {
        let pattern: Option<Pattern<'a>> = self.0.cast_first_match();
        pattern.unwrap_or_default()
    }

    /// Returns the last `Expr`-castable child (the assigned value), or the default.
    pub fn value(self) -> Expr<'a> {
        let value: Option<Expr<'a>> = self.0.cast_last_match();
        value.unwrap_or_default()
    }
}
node! {
    /// A set rule (kind `SyntaxKind::SetRule`).
    SetRule
}

impl<'a> SetRule<'a> {
    /// Returns the first `Expr`-castable child (the target), or the default.
    pub fn target(self) -> Expr<'a> {
        self.0.cast_first_match::<Expr>().unwrap_or_default()
    }

    /// Returns the last `Args` child, or the placeholder.
    pub fn args(self) -> Args<'a> {
        self.0.cast_last_match::<Args>().unwrap_or_default()
    }

    /// Returns the first expression at or after the `If` keyword, if any.
    pub fn condition(self) -> Option<Expr<'a>> {
        let from_if = self
            .0
            .children()
            .skip_while(|child| child.kind() != SyntaxKind::If);
        from_if.filter_map(SyntaxNode::cast).next()
    }
}
node! {
    /// A show rule (kind `SyntaxKind::ShowRule`).
    ShowRule
}

impl<'a> ShowRule<'a> {
    /// Returns the selector: scanning backwards from the end, the first
    /// expression at or before the last `Colon`. Returns `None` if no colon
    /// or expression precedes it.
    pub fn selector(self) -> Option<Expr<'a>> {
        let up_to_colon = self
            .0
            .children()
            .rev()
            .skip_while(|child| child.kind() != SyntaxKind::Colon);
        up_to_colon.filter_map(SyntaxNode::cast).next()
    }

    /// Returns the last `Expr`-castable child (the transformation), or the default.
    pub fn transform(self) -> Expr<'a> {
        self.0.cast_last_match::<Expr>().unwrap_or_default()
    }
}
node! {
    /// A contextual expression (kind `SyntaxKind::Contextual`).
    Contextual
}

impl<'a> Contextual<'a> {
    /// Returns the first `Expr`-castable child (the body), or the default.
    pub fn body(self) -> Expr<'a> {
        let body: Option<Expr<'a>> = self.0.cast_first_match();
        body.unwrap_or_default()
    }
}
node! {
    /// An if-else conditional (kind `SyntaxKind::Conditional`).
    Conditional
}

impl<'a> Conditional<'a> {
    /// Returns the first `Expr`-castable child (the condition), or the default.
    pub fn condition(self) -> Expr<'a> {
        self.0.cast_first_match::<Expr>().unwrap_or_default()
    }

    /// Returns the second expression child (the `if` branch), or the default.
    pub fn if_body(self) -> Expr<'a> {
        let body: Option<Expr<'a>> = self.0.children().filter_map(SyntaxNode::cast).nth(1);
        body.unwrap_or_default()
    }

    /// Returns the third expression child (the `else` branch), if present.
    pub fn else_body(self) -> Option<Expr<'a>> {
        let mut exprs = self.0.children().filter_map(SyntaxNode::cast);
        exprs.nth(2)
    }
}
node! {
    /// A while loop (kind `SyntaxKind::WhileLoop`).
    WhileLoop
}

impl<'a> WhileLoop<'a> {
    /// Returns the first `Expr`-castable child (the condition), or the default.
    pub fn condition(self) -> Expr<'a> {
        let condition: Option<Expr<'a>> = self.0.cast_first_match();
        condition.unwrap_or_default()
    }

    /// Returns the last `Expr`-castable child (the body), or the default.
    pub fn body(self) -> Expr<'a> {
        let body: Option<Expr<'a>> = self.0.cast_last_match();
        body.unwrap_or_default()
    }
}
node! {
    /// A for loop (kind `SyntaxKind::ForLoop`).
    ForLoop
}

impl<'a> ForLoop<'a> {
    /// Returns the first `Pattern`-castable child (the loop pattern), or the default.
    pub fn pattern(self) -> Pattern<'a> {
        self.0.cast_first_match::<Pattern>().unwrap_or_default()
    }

    /// Returns the first expression at or after the `In` keyword, or the default.
    pub fn iterable(self) -> Expr<'a> {
        let from_in = self
            .0
            .children()
            .skip_while(|child| child.kind() != SyntaxKind::In);
        let iterable: Option<Expr<'a>> = from_in.filter_map(SyntaxNode::cast).next();
        iterable.unwrap_or_default()
    }

    /// Returns the last `Expr`-castable child (the body), or the default.
    pub fn body(self) -> Expr<'a> {
        self.0.cast_last_match::<Expr>().unwrap_or_default()
    }
}
node! {
    /// A module import (kind `SyntaxKind::ModuleImport`).
    ModuleImport
}

impl<'a> ModuleImport<'a> {
    /// Returns the first `Expr`-castable child (the imported source), or the default.
    pub fn source(self) -> Expr<'a> {
        self.0.cast_first_match::<Expr>().unwrap_or_default()
    }

    /// Returns what is imported: a wildcard if the first relevant child is a
    /// `Star`, the item list if it is `ImportItems`, or `None` if neither occurs.
    pub fn imports(self) -> Option<Imports<'a>> {
        self.0.children().find_map(|child| {
            let kind = child.kind();
            if kind == SyntaxKind::Star {
                Some(Imports::Wildcard)
            } else if kind == SyntaxKind::ImportItems {
                child.cast().map(Imports::Items)
            } else {
                Option::None
            }
        })
    }

    /// Returns the first identifier at or after the `As` keyword, if any.
    pub fn new_name(self) -> Option<Ident<'a>> {
        let from_as = self
            .0
            .children()
            .skip_while(|child| child.kind() != SyntaxKind::As);
        from_as.filter_map(SyntaxNode::cast).next()
    }
}

/// The items imported by a module import.
#[derive(Debug, Copy, Clone, Hash)]
pub enum Imports<'a> {
    /// A wildcard (`*`) import.
    Wildcard,
    /// An explicit list of items.
    Items(ImportItems<'a>),
}
node! {
    /// A list of import items (kind `SyntaxKind::ImportItems`).
    ImportItems
}

impl<'a> ImportItems<'a> {
    /// Iterates over the import items. `RenamedImportItem` children become
    /// [`ImportItem::Renamed`] and `ImportItemPath` children become
    /// [`ImportItem::Simple`]; all other children are skipped.
    pub fn iter(self) -> impl DoubleEndedIterator<Item = ImportItem<'a>> {
        self.0.children().filter_map(|child| {
            let kind = child.kind();
            if kind == SyntaxKind::RenamedImportItem {
                child.cast().map(ImportItem::Renamed)
            } else if kind == SyntaxKind::ImportItemPath {
                child.cast().map(ImportItem::Simple)
            } else {
                Option::None
            }
        })
    }
}

node! {
    /// A path to an imported item (kind `SyntaxKind::ImportItemPath`).
    ImportItemPath
}

impl<'a> ImportItemPath<'a> {
    /// Iterates over the path's identifier children.
    pub fn iter(self) -> impl DoubleEndedIterator<Item = Ident<'a>> {
        let children = self.0.children();
        children.filter_map(SyntaxNode::cast)
    }

    /// Returns the path's last identifier, or the placeholder if there is none.
    pub fn name(self) -> Ident<'a> {
        self.iter().next_back().unwrap_or_default()
    }
}
/// A single imported item, optionally renamed.
#[derive(Debug, Copy, Clone, Hash)]
pub enum ImportItem<'a> {
    /// An item imported under its own name.
    Simple(ImportItemPath<'a>),
    /// An item imported under a new name.
    Renamed(RenamedImportItem<'a>),
}

impl<'a> ImportItem<'a> {
    /// Returns the path to the imported item.
    pub fn path(self) -> ImportItemPath<'a> {
        match self {
            Self::Renamed(item) => item.path(),
            Self::Simple(path) => path,
        }
    }

    /// Returns the item's name as declared in the source module.
    pub fn original_name(self) -> Ident<'a> {
        match self {
            Self::Renamed(item) => item.original_name(),
            Self::Simple(path) => path.name(),
        }
    }

    /// Returns the name the item is bound to at the import site: the new name
    /// if renamed, otherwise the original name.
    pub fn bound_name(self) -> Ident<'a> {
        match self {
            Self::Renamed(item) => item.new_name(),
            Self::Simple(path) => path.name(),
        }
    }
}
node! {
    /// An import item with a new name (kind `SyntaxKind::RenamedImportItem`).
    RenamedImportItem
}

impl<'a> RenamedImportItem<'a> {
    /// Returns the first `ImportItemPath` child, or the placeholder.
    pub fn path(self) -> ImportItemPath<'a> {
        let path: Option<ImportItemPath<'a>> = self.0.cast_first_match();
        path.unwrap_or_default()
    }

    /// Returns the last identifier of the path.
    pub fn original_name(self) -> Ident<'a> {
        let path = self.path();
        path.name()
    }

    /// Returns the last `Ident` child (the new name), or the placeholder.
    pub fn new_name(self) -> Ident<'a> {
        let mut idents = self.0.children().filter_map(SyntaxNode::cast);
        idents.next_back().unwrap_or_default()
    }
}
node! {
    /// A module include (kind `SyntaxKind::ModuleInclude`).
    ModuleInclude
}

impl<'a> ModuleInclude<'a> {
    /// Returns the last `Expr`-castable child (the included source), or the default.
    pub fn source(self) -> Expr<'a> {
        let source: Option<Expr<'a>> = self.0.cast_last_match();
        source.unwrap_or_default()
    }
}

node! {
    /// A `break` (kind `SyntaxKind::LoopBreak`).
    LoopBreak
}

node! {
    /// A `continue` (kind `SyntaxKind::LoopContinue`).
    LoopContinue
}

node! {
    /// A `return` (kind `SyntaxKind::FuncReturn`).
    FuncReturn
}

impl<'a> FuncReturn<'a> {
    /// Returns the last `Expr`-castable child (the returned value), if any.
    pub fn body(self) -> Option<Expr<'a>> {
        self.0.cast_last_match::<Expr>()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expr_default() {
        // The default expression's node must cast back to an `Expr`.
        let node = Expr::default().to_untyped();
        assert!(node.cast::<Expr>().is_some());
    }
}