use serde::ser::{Serialize, SerializeStruct, Serializer};
use super::{ConcreteSyntaxTree, VB6Language};
use crate::parsers::SyntaxKind;
/// An owned syntax-tree node holding a kind, its text, a token flag, and child nodes.
///
/// All fields are private; they are read through the accessor methods on `CstNode`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CstNode {
/// The syntax kind of this node or token.
kind: SyntaxKind,
/// The text stored for this node (serialized only when `is_token` is true).
text: String,
/// `true` when this entry is a token, `false` when it is an interior node.
is_token: bool,
/// Child nodes in order (serialized only when `is_token` is false).
children: Vec<CstNode>,
}
impl Serialize for CstNode {
    /// Serializes the node as a two-field struct named `CstNode`.
    ///
    /// The first field is always `kind`. The second is `text` for tokens and
    /// `children` for everything else, so exactly two fields are written.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // `kind` plus exactly one payload field.
        const FIELD_COUNT: usize = 2;

        let mut record = serializer.serialize_struct("CstNode", FIELD_COUNT)?;
        record.serialize_field("kind", &self.kind)?;

        let payload_written = if self.is_token {
            record.serialize_field("text", &self.text)
        } else {
            record.serialize_field("children", &self.children)
        };
        payload_written?;

        record.end()
    }
}
impl CstNode {
pub(crate) fn new(
kind: SyntaxKind,
text: String,
is_token: bool,
children: Vec<CstNode>,
) -> Self {
Self {
kind,
text,
is_token,
children,
}
}
}
impl CstNode {
#[inline]
#[must_use]
pub fn kind(&self) -> SyntaxKind {
self.kind
}
#[inline]
#[must_use]
pub fn text(&self) -> &str {
&self.text
}
#[inline]
#[must_use]
pub fn is_token(&self) -> bool {
self.is_token
}
#[inline]
#[must_use]
pub fn children(&self) -> &[CstNode] {
&self.children
}
#[must_use]
pub fn child_count(&self) -> usize {
self.children.len()
}
#[must_use]
pub fn first_child(&self) -> Option<&CstNode> {
self.children.first()
}
#[must_use]
pub fn last_child(&self) -> Option<&CstNode> {
self.children.last()
}
#[must_use]
pub fn child_at(&self, index: usize) -> Option<&CstNode> {
self.children.get(index)
}
pub fn children_by_kind(&self, kind: SyntaxKind) -> impl Iterator<Item = &CstNode> {
self.children()
.iter()
.filter(move |child| child.kind() == kind)
}
#[must_use]
pub fn first_child_by_kind(&self, kind: SyntaxKind) -> Option<&CstNode> {
self.children().iter().find(|child| child.kind() == kind)
}
#[must_use]
pub fn contains_kind(&self, kind: SyntaxKind) -> bool {
self.children().iter().any(|child| child.kind() == kind)
}
#[must_use]
pub fn find(&self, kind: SyntaxKind) -> Option<&CstNode> {
if self.kind() == kind {
return Some(self);
}
for child in self.children() {
if let Some(found) = child.find(kind) {
return Some(found);
}
}
None
}
#[must_use]
pub fn find_all(&self, kind: SyntaxKind) -> Vec<&CstNode> {
let mut results = Vec::new();
self.find_all_recursive(kind, &mut results);
results
}
fn find_all_recursive<'a>(&'a self, kind: SyntaxKind, results: &mut Vec<&'a CstNode>) {
if self.kind() == kind {
results.push(self);
}
for child in self.children() {
child.find_all_recursive(kind, results);
}
}
pub fn non_token_children(&self) -> impl Iterator<Item = &CstNode> {
self.children().iter().filter(|child| !child.is_token())
}
pub fn token_children(&self) -> impl Iterator<Item = &CstNode> {
self.children().iter().filter(|child| child.is_token())
}
#[must_use]
pub fn first_non_whitespace_child(&self) -> Option<&CstNode> {
self.children()
.iter()
.find(|child| child.kind() != SyntaxKind::Whitespace)
}
pub fn significant_children(&self) -> impl Iterator<Item = &CstNode> {
self.children().iter().filter(|child| {
child.kind() != SyntaxKind::Whitespace && child.kind() != SyntaxKind::Newline
})
}
#[must_use]
pub fn find_if<F>(&self, predicate: F) -> Option<&CstNode>
where
F: Fn(&CstNode) -> bool,
{
self.find_if_internal(&predicate)
}
fn find_if_internal(&self, predicate: &dyn Fn(&CstNode) -> bool) -> Option<&CstNode> {
if predicate(self) {
return Some(self);
}
for child in self.children() {
if let Some(found) = child.find_if_internal(predicate) {
return Some(found);
}
}
None
}
#[must_use]
pub fn find_all_if<F>(&self, predicate: F) -> Vec<&CstNode>
where
F: Fn(&CstNode) -> bool,
{
let mut results = Vec::new();
self.find_all_if_internal(&predicate, &mut results);
results
}
fn find_all_if_internal<'a>(
&'a self,
predicate: &dyn Fn(&CstNode) -> bool,
results: &mut Vec<&'a CstNode>,
) {
if predicate(self) {
results.push(self);
}
for child in self.children() {
child.find_all_if_internal(predicate, results);
}
}
#[must_use]
pub fn descendants(&self) -> DepthFirstIter<'_> {
DepthFirstIter { stack: vec![self] }
}
#[must_use]
pub fn depth_first_iter(&self) -> DepthFirstIter<'_> {
self.descendants()
}
#[must_use]
pub fn is_whitespace(&self) -> bool {
self.kind() == SyntaxKind::Whitespace
}
#[must_use]
pub fn is_newline(&self) -> bool {
self.kind() == SyntaxKind::Newline
}
#[must_use]
pub fn is_comment(&self) -> bool {
matches!(
self.kind(),
SyntaxKind::EndOfLineComment | SyntaxKind::RemComment
)
}
#[must_use]
pub fn is_significant(&self) -> bool {
!self.is_trivia()
}
#[must_use]
pub fn is_trivia(&self) -> bool {
self.is_whitespace() || self.is_newline() || self.is_comment()
}
}
/// Borrowing depth-first iterator over a `CstNode` and its descendants.
///
/// Each node is yielded before its children, and children are visited from
/// first to last.
pub struct DepthFirstIter<'a> {
/// Nodes still to be yielded; the next node to yield is at the end.
stack: Vec<&'a CstNode>,
}
impl<'a> Iterator for DepthFirstIter<'a> {
    type Item = &'a CstNode;

    /// Yields the next node in pre-order, or `None` once the stack is empty.
    fn next(&mut self) -> Option<Self::Item> {
        let current = self.stack.pop()?;
        // Push children back-to-front so the first child is popped next.
        self.stack.extend(current.children().iter().rev());
        Some(current)
    }
}
/// Owning depth-first iterator that yields cloned `CstNode` values.
///
/// Each node is yielded before its children, and children are visited from
/// first to last. Every yielded node still carries its own `children`.
pub struct DepthFirstIterOwned {
/// Nodes still to be yielded; the next node to yield is at the end.
stack: Vec<CstNode>,
}
impl Iterator for DepthFirstIterOwned {
    type Item = CstNode;

    /// Yields the next node in pre-order, or `None` once the stack is empty.
    ///
    /// The yielded node keeps its children, so each child is cloned onto the
    /// stack rather than moved out of its parent.
    fn next(&mut self) -> Option<Self::Item> {
        let current = self.stack.pop()?;
        // Push clones back-to-front so the first child is popped next.
        self.stack
            .extend(current.children().iter().rev().cloned());
        Some(current)
    }
}
impl ConcreteSyntaxTree {
    /// Creates a rowan syntax node rooted at a clone of this tree's `root`.
    fn navigation_syntax_root(&self) -> rowan::SyntaxNode<VB6Language> {
        rowan::SyntaxNode::<VB6Language>::new_root(self.root.clone())
    }

    /// Converts the direct children accepted by `keep` into owned `CstNode`s.
    ///
    /// Filtering happens on the rowan elements, so only the accepted children
    /// (and their subtrees) are converted.
    fn collect_children_where<F>(&self, keep: F) -> Vec<CstNode>
    where
        F: FnMut(
            &rowan::NodeOrToken<rowan::SyntaxNode<VB6Language>, rowan::SyntaxToken<VB6Language>>,
        ) -> bool,
    {
        self.navigation_syntax_root()
            .children_with_tokens()
            .filter(keep)
            .map(Self::build_cst_node)
            .collect()
    }

    /// Returns the rowan debug rendering (`{:#?}`) of the whole tree.
    #[must_use]
    pub fn debug_tree(&self) -> String {
        let syntax_node = self.navigation_syntax_root();
        format!("{syntax_node:#?}")
    }

    /// Returns the text of the root syntax node as an owned `String`.
    #[must_use]
    pub fn text(&self) -> String {
        self.navigation_syntax_root().text().to_string()
    }

    /// Returns the number of direct children of the root.
    #[must_use]
    pub fn child_count(&self) -> usize {
        self.root.children().count()
    }

    /// Returns every direct child of the root (nodes and tokens) as owned `CstNode`s.
    ///
    /// Each child is converted together with its entire subtree.
    #[must_use]
    pub fn children(&self) -> Vec<CstNode> {
        self.collect_children_where(|_| true)
    }

    /// Recursively converts a rowan node or token into an owned `CstNode`.
    ///
    /// Nodes keep their full text and converted children; tokens keep their
    /// text and have no children.
    fn build_cst_node(
        node_or_token: rowan::NodeOrToken<
            rowan::SyntaxNode<VB6Language>,
            rowan::SyntaxToken<VB6Language>,
        >,
    ) -> CstNode {
        match node_or_token {
            rowan::NodeOrToken::Node(node) => {
                let children = node
                    .children_with_tokens()
                    .map(Self::build_cst_node)
                    .collect();
                CstNode {
                    kind: node.kind(),
                    text: node.text().to_string(),
                    is_token: false,
                    children,
                }
            }
            rowan::NodeOrToken::Token(token) => CstNode {
                kind: token.kind(),
                text: token.text().to_string(),
                is_token: true,
                children: Vec::new(),
            },
        }
    }

    /// Iterates over the direct children of the root whose kind equals `kind`.
    ///
    /// Only matching children are converted to `CstNode`.
    pub fn children_by_kind(&self, kind: SyntaxKind) -> impl Iterator<Item = CstNode> {
        self.collect_children_where(|element| element.kind() == kind)
            .into_iter()
    }

    /// Returns the first direct child of the root whose kind equals `kind`.
    #[must_use]
    pub fn first_child_by_kind(&self, kind: SyntaxKind) -> Option<CstNode> {
        self.navigation_syntax_root()
            .children_with_tokens()
            .find(|element| element.kind() == kind)
            .map(Self::build_cst_node)
    }

    /// Returns `true` when at least one direct child of the root has kind `kind`.
    ///
    /// Only kinds are compared; no `CstNode` is built.
    #[must_use]
    pub fn contains_kind(&self, kind: SyntaxKind) -> bool {
        self.navigation_syntax_root()
            .children_with_tokens()
            .any(|element| element.kind() == kind)
    }

    /// Returns the first direct child of the root, or `None` for an empty tree.
    #[must_use]
    pub fn first_child(&self) -> Option<CstNode> {
        self.navigation_syntax_root()
            .children_with_tokens()
            .next()
            .map(Self::build_cst_node)
    }

    /// Returns the last direct child of the root, or `None` for an empty tree.
    #[must_use]
    pub fn last_child(&self) -> Option<CstNode> {
        self.navigation_syntax_root()
            .children_with_tokens()
            .last()
            .map(Self::build_cst_node)
    }

    /// Returns the direct child of the root at `index`, or `None` when out of range.
    #[must_use]
    pub fn child_at(&self, index: usize) -> Option<CstNode> {
        self.navigation_syntax_root()
            .children_with_tokens()
            .nth(index)
            .map(Self::build_cst_node)
    }

    /// Searches the whole tree, pre-order, for the first node of `kind` and returns a clone.
    #[must_use]
    pub fn find(&self, kind: SyntaxKind) -> Option<CstNode> {
        self.to_root_node().find(kind).cloned()
    }

    /// Returns clones of every node in the tree whose kind equals `kind`, in pre-order.
    #[must_use]
    pub fn find_all(&self, kind: SyntaxKind) -> Vec<CstNode> {
        self.to_root_node()
            .find_all(kind)
            .into_iter()
            .cloned()
            .collect()
    }

    /// Iterates over the direct children of the root that are nodes rather than tokens.
    pub fn non_token_children(&self) -> impl Iterator<Item = CstNode> {
        self.collect_children_where(|element| matches!(element, rowan::NodeOrToken::Node(_)))
            .into_iter()
    }

    /// Iterates over the direct children of the root that are tokens.
    pub fn token_children(&self) -> impl Iterator<Item = CstNode> {
        self.collect_children_where(|element| matches!(element, rowan::NodeOrToken::Token(_)))
            .into_iter()
    }

    /// Returns the first direct child of the root whose kind is not
    /// `SyntaxKind::Whitespace`.
    #[must_use]
    pub fn first_non_whitespace_child(&self) -> Option<CstNode> {
        self.navigation_syntax_root()
            .children_with_tokens()
            .find(|element| element.kind() != SyntaxKind::Whitespace)
            .map(Self::build_cst_node)
    }

    /// Iterates over the direct children of the root that are neither
    /// whitespace nor newlines.
    ///
    /// Comment tokens are kept by this filter.
    pub fn significant_children(&self) -> impl Iterator<Item = CstNode> {
        self.collect_children_where(|element| {
            !matches!(
                element.kind(),
                SyntaxKind::Whitespace | SyntaxKind::Newline
            )
        })
        .into_iter()
    }

    /// Searches the whole tree, pre-order, for the first node accepted by
    /// `predicate` and returns a clone.
    #[must_use]
    pub fn find_if<F>(&self, predicate: F) -> Option<CstNode>
    where
        F: Fn(&CstNode) -> bool,
    {
        self.to_root_node().find_if(predicate).cloned()
    }

    /// Returns clones of every node in the tree accepted by `predicate`, in pre-order.
    #[must_use]
    pub fn find_all_if<F>(&self, predicate: F) -> Vec<CstNode>
    where
        F: Fn(&CstNode) -> bool,
    {
        self.to_root_node()
            .find_all_if(predicate)
            .into_iter()
            .cloned()
            .collect()
    }

    /// Returns an owning pre-order iterator starting at the root node.
    #[must_use]
    pub fn descendants(&self) -> DepthFirstIterOwned {
        DepthFirstIterOwned {
            stack: vec![self.to_root_node()],
        }
    }

    /// Alias for [`ConcreteSyntaxTree::descendants`].
    #[must_use]
    pub fn depth_first_iter(&self) -> DepthFirstIterOwned {
        self.descendants()
    }
}
#[cfg(test)]
mod tests {
    use crate::parsers::{ConcreteSyntaxTree, CstNode, SyntaxKind};

    /// Parses `source` as `test.bas` and returns the resulting tree.
    ///
    /// Panics with "Failed to parse source" when no tree is produced;
    /// `#[track_caller]` keeps the panic location in the calling test.
    #[track_caller]
    fn parse(source: &str) -> ConcreteSyntaxTree {
        let (tree, _failures) = ConcreteSyntaxTree::from_text("test.bas", source).unpack();
        tree.expect("Failed to parse source")
    }

    #[test]
    fn navigation_children() {
        let cst = parse("Attribute VB_Name\nSub Test()\nEnd Sub\n");

        let children = cst.children();
        assert_eq!(children.len(), 2);

        let (attribute, sub) = (&children[0], &children[1]);
        assert_eq!(attribute.kind(), SyntaxKind::AttributeStatement);
        assert_eq!(sub.kind(), SyntaxKind::SubStatement);
        assert!(!attribute.is_token());
        assert!(!sub.is_token());
    }

    #[test]
    fn navigation_children_by_kind() {
        let cst = parse("Dim x\nDim y\nSub Test()\nEnd Sub\n");

        let dim_total = cst.children_by_kind(SyntaxKind::DimStatement).count();
        assert_eq!(dim_total, 2);

        let sub_total = cst.children_by_kind(SyntaxKind::SubStatement).count();
        assert_eq!(sub_total, 1);

        assert!(cst.first_child_by_kind(SyntaxKind::DimStatement).is_some());
        assert!(cst
            .first_child_by_kind(SyntaxKind::FunctionStatement)
            .is_none());
    }

    #[test]
    fn navigation_contains_kind() {
        let cst = parse("Sub Test()\nEnd Sub\n");

        assert!(cst.contains_kind(SyntaxKind::SubStatement));
        for absent in [SyntaxKind::FunctionStatement, SyntaxKind::DimStatement] {
            assert!(!cst.contains_kind(absent));
        }
    }

    #[test]
    fn navigation_first_and_last_child() {
        let cst = parse("Attribute VB_Name\nDim x\nSub Test()\nEnd Sub\n");

        let head = cst.first_child().expect("Expected at least one child");
        assert_eq!(head.kind(), SyntaxKind::AttributeStatement);
        assert_eq!(head.text(), "Attribute VB_Name\n");

        let tail = cst.last_child().expect("Expected at least one child");
        assert_eq!(tail.kind(), SyntaxKind::SubStatement);
    }

    #[test]
    fn navigation_child_at() {
        let cst = parse("Attribute VB_Name\nDim x\nSub Test()\nEnd Sub\n");

        let at_zero = cst.child_at(0).expect("Expected child at index 0");
        assert_eq!(at_zero.kind(), SyntaxKind::AttributeStatement);

        let at_one = cst.child_at(1).expect("Expected child at index 1");
        assert_eq!(at_one.kind(), SyntaxKind::DimStatement);

        let at_two = cst.child_at(2).expect("Expected child at index 2");
        assert_eq!(at_two.kind(), SyntaxKind::SubStatement);

        assert!(cst.child_at(4).is_none());
    }

    #[test]
    fn navigation_empty_tree() {
        let cst = parse("");

        assert_eq!(cst.children().len(), 0);
        assert!(cst.first_child().is_none());
        assert!(cst.last_child().is_none());
        assert!(cst.child_at(0).is_none());
        assert!(!cst.contains_kind(SyntaxKind::SubStatement));
    }

    #[test]
    fn navigation_with_comments_and_whitespace() {
        let cst = parse("' Comment\n\nSub Test()\nEnd Sub\n");

        let children = cst.children();
        assert_eq!(children.len(), 4);

        // The three leading entries are trivia tokens, in this order.
        let leading_tokens = [
            SyntaxKind::EndOfLineComment,
            SyntaxKind::Newline,
            SyntaxKind::Newline,
        ];
        for (child, expected_kind) in children.iter().zip(leading_tokens) {
            assert_eq!(child.kind(), expected_kind);
            assert!(child.is_token());
        }

        assert_eq!(children[3].kind(), SyntaxKind::SubStatement);
        assert!(!children[3].is_token());
    }

    #[test]
    fn cst_node_basic_navigation() {
        let root = parse("Sub Test()\nEnd Sub\n").to_serializable().root;

        assert_eq!(root.child_count(), 1);
        assert!(root.first_child().is_some());
        assert!(root.last_child().is_some());
        assert!(root.child_at(0).is_some());
        assert!(root.child_at(10).is_none());

        let head = root.first_child().expect("Expected first child");
        assert_eq!(head.kind(), SyntaxKind::SubStatement);
    }

    #[test]
    fn cst_node_filter_by_kind() {
        let root = parse("Dim x\nDim y\nSub Test()\nEnd Sub\n")
            .to_serializable()
            .root;

        let dim_total = root.children_by_kind(SyntaxKind::DimStatement).count();
        assert_eq!(dim_total, 2);

        assert!(root.first_child_by_kind(SyntaxKind::DimStatement).is_some());
        assert!(root.contains_kind(SyntaxKind::SubStatement));
        assert!(!root.contains_kind(SyntaxKind::FunctionStatement));
    }

    #[test]
    fn cst_node_recursive_find() {
        let root = parse("Sub Test()\nDim x As Integer\nEnd Sub\n")
            .to_serializable()
            .root;

        // A single comparison covers both "found" and "has the right kind".
        let found_kind = root.find(SyntaxKind::DimStatement).map(CstNode::kind);
        assert_eq!(found_kind, Some(SyntaxKind::DimStatement));

        let identifiers = root.find_all(SyntaxKind::Identifier);
        assert!(identifiers.len() >= 2);
    }

    #[test]
    fn cst_node_token_filtering() {
        let root = parse("Sub Test()\n Dim x\nEnd Sub\n").to_serializable().root;

        let non_tokens: Vec<_> = root.non_token_children().collect();
        // Collected only to exercise the iterator; the result is not inspected.
        let _tokens: Vec<_> = root.token_children().collect();
        assert!(!non_tokens.is_empty());

        let first_is_whitespace = root
            .first_non_whitespace_child()
            .map(CstNode::is_whitespace);
        assert_eq!(first_is_whitespace, Some(false));

        let significant: Vec<_> = root.significant_children().collect();
        assert!(significant
            .iter()
            .all(|node| !matches!(node.kind(), SyntaxKind::Whitespace | SyntaxKind::Newline)));
    }

    #[test]
    fn concrete_syntax_tree_recursive_find() {
        let cst = parse("Sub Test()\nDim x As Integer\nEnd Sub\n");

        let found_kind = cst.find(SyntaxKind::DimStatement).map(|node| node.kind());
        assert_eq!(found_kind, Some(SyntaxKind::DimStatement));

        let identifiers = cst.find_all(SyntaxKind::Identifier);
        assert!(identifiers.len() >= 2);

        // The `Dim` statement is nested inside the `Sub`, so it is not a direct child.
        let direct_dims = cst.children_by_kind(SyntaxKind::DimStatement).count();
        assert_eq!(direct_dims, 0);
    }

    #[test]
    fn concrete_syntax_tree_token_filtering() {
        let cst = parse("Sub Test()\n Dim x\nEnd Sub\n");

        let non_tokens: Vec<_> = cst.non_token_children().collect();
        assert!(!non_tokens.is_empty());
        assert!(non_tokens.iter().all(|node| !node.is_token()));

        let cst2 = parse(" \nSub Test()\nEnd Sub\n");
        if let Some(node) = cst2.first_non_whitespace_child() {
            assert_ne!(node.kind(), SyntaxKind::Whitespace);
        }

        let significant: Vec<_> = cst.significant_children().collect();
        assert!(significant
            .iter()
            .all(|node| !matches!(node.kind(), SyntaxKind::Whitespace | SyntaxKind::Newline)));
    }

    #[test]
    fn cst_node_predicate_search() {
        let root = parse("Sub Test()\nDim x As Integer\nEnd Sub\n")
            .to_serializable()
            .root;

        let first_non_token = root.find_if(|node| !node.is_token());
        assert_eq!(first_non_token.map(CstNode::is_token), Some(false));

        let keywords = root.find_all_if(|node| {
            matches!(
                node.kind(),
                SyntaxKind::SubKeyword | SyntaxKind::DimKeyword | SyntaxKind::AsKeyword
            )
        });
        assert!(keywords.len() >= 3);
    }

    #[test]
    fn concrete_syntax_tree_predicate_search() {
        let cst = parse("Sub Test()\nDim x As Integer\nEnd Sub\n");

        let first_non_token = cst.find_if(|node| !node.is_token());
        assert_eq!(
            first_non_token.as_ref().map(CstNode::is_token),
            Some(false)
        );

        let keywords = cst.find_all_if(|node| {
            matches!(
                node.kind(),
                SyntaxKind::SubKeyword | SyntaxKind::DimKeyword | SyntaxKind::AsKeyword
            )
        });
        assert!(keywords.len() >= 3);

        let complex_nodes = cst.find_all_if(|node| !node.is_token() && node.child_count() > 2);
        assert!(!complex_nodes.is_empty());
    }

    #[test]
    fn cst_node_convenience_checkers() {
        let root = parse("' Comment\nSub Test()\nEnd Sub\n")
            .to_serializable()
            .root;

        let comment = root
            .find(SyntaxKind::EndOfLineComment)
            .expect("Expected to find comment node");
        assert!(comment.is_comment());
        assert!(comment.is_trivia());
        assert!(!comment.is_significant());

        let sub_stmt = root
            .find(SyntaxKind::SubStatement)
            .expect("Expected to find SubStatement node");
        assert!(sub_stmt.is_significant());
        assert!(!sub_stmt.is_trivia());
        assert!(!sub_stmt.is_whitespace());
        assert!(!sub_stmt.is_newline());
        assert!(!sub_stmt.is_comment());
    }

    #[test]
    fn cst_node_iterator_traversal() {
        let root = parse("Sub Test()\nDim x\nEnd Sub\n").to_serializable().root;

        let all_nodes: Vec<_> = root.descendants().collect();
        assert!(!all_nodes.is_empty());
        assert_eq!(all_nodes[0].kind(), SyntaxKind::Root);

        let identifier_count = root
            .descendants()
            .filter(|node| node.kind() == SyntaxKind::Identifier)
            .count();
        assert!(identifier_count >= 2);

        assert_eq!(root.depth_first_iter().count(), all_nodes.len());
    }

    #[test]
    fn concrete_syntax_tree_iterator_traversal() {
        let cst = parse("Sub Test()\nDim x\nEnd Sub\n");

        let all_nodes: Vec<_> = cst.descendants().collect();
        assert!(!all_nodes.is_empty());

        let identifier_count = cst
            .descendants()
            .filter(|node| node.kind() == SyntaxKind::Identifier)
            .count();
        assert!(identifier_count >= 2);

        assert_eq!(cst.depth_first_iter().count(), all_nodes.len());

        let non_trivia_count = cst.descendants().filter(CstNode::is_significant).count();
        assert!(non_trivia_count > 0);
    }
}