use super::span::{Span, TokenRange};
use super::tokens::Token;
use super::{lex, reconstruct};
use crate::v1::ast::AstNode;
use crate::v1::parser::parse as v1_parse;
/// One node of the lossless tree: a structural mirror of a single `AstNode`, annotated
/// with the half-open range of lexer tokens it covers inside its owning [`Tree`].
#[derive(Debug, Clone)]
pub struct TreeNode {
    /// Node kind label, as returned by `NodeInfo::label` (e.g. `"Command"`).
    pub label: &'static str,
    /// The AST node's scalar data rendered as a string by `NodeInfo::scalars`.
    pub scalars: String,
    /// The AST node's source position (`loc().pos`); it is compared against token
    /// `span.start` values when the token range is assigned.
    pub pos: usize,
    /// Half-open index range `[first, end)` into `Tree::tokens`. It may be empty when the
    /// node could not be mapped onto whole tokens.
    pub range: TokenRange,
    /// Mirrors of the AST node's `safe_children()`, in the same order.
    pub children: Vec<TreeNode>,
}
impl TreeNode {
pub fn tokens<'a>(&self, tree: &'a Tree) -> &'a [Token] {
&tree.tokens[self.range.first..self.range.end]
}
pub fn unparse_lossless(&self, tree: &Tree) -> String {
reconstruct(self.tokens(tree))
}
pub fn source_span(&self, tree: &Tree) -> Option<Span> {
let toks = self.tokens(tree);
match (toks.first(), toks.last()) {
(Some(f), Some(l)) => Some(f.full_span().join(l.full_span())),
_ => None,
}
}
pub fn walk(&self, visitor: &mut impl FnMut(&TreeNode)) {
visitor(self);
for child in &self.children {
child.walk(visitor);
}
}
pub fn walk_with_ancestors(&self, visitor: &mut impl FnMut(&TreeNode, &[&TreeNode])) {
fn go<'a>(
node: &'a TreeNode,
stack: &mut Vec<&'a TreeNode>,
visitor: &mut impl FnMut(&TreeNode, &[&TreeNode]),
) {
visitor(node, stack);
stack.push(node);
for child in &node.children {
go(child, stack, visitor);
}
stack.pop();
}
go(self, &mut Vec::new(), visitor);
}
}
/// A parse result that pairs the v1 typed AST with a token-ranged mirror tree and the
/// full token stream, so any node's exact source text (trivia included) can be recovered.
#[derive(Debug, Clone)]
pub struct Tree {
    /// Every token produced by `lex` for the input. `reconstruct` over all of them is
    /// expected to reproduce the input exactly (the tests in this module assert it).
    pub tokens: Vec<Token>,
    /// Mirror of `ast`. Node ranges index into `tokens`.
    pub root: TreeNode,
    /// The v1 parse result wrapped as `AstNode::ScriptBlock`.
    pub ast: AstNode,
    /// Diagnostics returned by the v1 parser.
    pub errors: Vec<String>,
}
impl Tree {
    /// Walks the typed AST and its `TreeNode` mirror in lock-step, in pre-order.
    ///
    /// `visitor` receives each `(ast, node)` pair plus the stack of ancestor pairs (root
    /// first, direct parent last). Children are paired positionally through
    /// `AstNode::safe_children`. In debug builds a child-count mismatch panics; in release
    /// builds the pairing simply stops at the shorter side.
    pub fn walk_zipped<'a, F>(&'a self, visitor: &mut F)
    where
        F: FnMut(&'a AstNode, &'a TreeNode, &[(&'a AstNode, &'a TreeNode)]),
    {
        fn visit<'a, F>(
            typed: &'a AstNode,
            mirror: &'a TreeNode,
            trail: &mut Vec<(&'a AstNode, &'a TreeNode)>,
            visitor: &mut F,
        ) where
            F: FnMut(&'a AstNode, &'a TreeNode, &[(&'a AstNode, &'a TreeNode)]),
        {
            visitor(typed, mirror, trail);

            let typed_children = typed.safe_children();
            debug_assert_eq!(
                typed_children.len(),
                mirror.children.len(),
                "tree mirror out of sync under {}",
                mirror.label
            );

            trail.push((typed, mirror));
            let pairs = typed_children.into_iter().zip(mirror.children.iter());
            for (typed_child, mirror_child) in pairs {
                visit(typed_child, mirror_child, trail, visitor);
            }
            trail.pop();
        }

        let mut trail = Vec::new();
        visit(&self.ast, &self.root, &mut trail, visitor);
    }

    /// Reconstructs the entire original source from the full token stream.
    pub fn unparse_lossless(&self) -> String {
        let all = &self.tokens;
        reconstruct(all)
    }
}
/// Lexes and parses `src`, then builds a [`Tree`] whose mirror nodes carry token ranges.
///
/// The lexer's token stream and the v1 parser's AST are produced independently from the
/// same source. The mirror tree is then aligned to the tokens by source position.
pub fn parse_with_tokens(src: &str) -> Tree {
    let tokens = lex(src).tokens;
    let (script, errors) = v1_parse(src);
    let ast = AstNode::ScriptBlock(script);

    // The builder borrows `tokens`; scoping it lets `tokens` move into the result below.
    let root = {
        let starts: Vec<usize> = tokens.iter().map(|tok| tok.span.start).collect();
        let builder = Builder {
            tokens: &tokens,
            starts: &starts,
        };
        let mut mirror = builder.build(&ast);
        builder.assign_ranges(&mut mirror, 0, tokens.len());
        mirror
    };

    Tree {
        tokens,
        root,
        ast,
        errors,
    }
}
/// Transient helper used by `parse_with_tokens` to build `TreeNode` mirrors of the AST
/// and assign their token ranges.
struct Builder<'a> {
    /// The lexer's token stream, in order.
    tokens: &'a [Token],
    /// `tokens[i].span.start` for every `i`. It is searched with binary search, so it is
    /// assumed to be sorted ascending.
    starts: &'a [usize],
}
impl Builder<'_> {
    /// Builds an unranged `TreeNode` mirror of `node` and, recursively, of every child
    /// returned by `safe_children()`, preserving child order.
    ///
    /// Every range starts as the empty `[0, 0)`; `assign_ranges` fills them in afterwards.
    fn build(&self, node: &AstNode) -> TreeNode {
        use crate::v1::ast::NodeInfo;
        TreeNode {
            label: node.label(),
            scalars: node.scalars(),
            pos: node.loc().pos,
            range: TokenRange { first: 0, end: 0 },
            children: node.safe_children().iter().map(|c| self.build(c)).collect(),
        }
    }

    /// Returns the index of the FIRST token whose start is `>= pos`. Returns `None` when
    /// every token starts before `pos`.
    ///
    /// `partition_point` is used instead of `binary_search` because the latter may return
    /// any one of several equal entries. If two tokens ever share a start offset (e.g.
    /// zero-width tokens), `binary_search` would pick an arbitrary one. Requires `starts`
    /// to be sorted ascending.
    fn token_at_or_after(&self, pos: usize) -> Option<usize> {
        let idx = self.starts.partition_point(|&s| s < pos);
        (idx < self.starts.len()).then_some(idx)
    }

    /// True when `pos` lies strictly inside some token's own span (after its first
    /// character), e.g. a position within a string token. Linear scan over all tokens.
    fn inside_a_token(&self, pos: usize) -> bool {
        self.tokens
            .iter()
            .any(|t| t.span.start < pos && pos < t.span.end)
    }

    /// Assigns a half-open token range to `node` and all its descendants, constrained to
    /// the token window `[lo, hi)`. Callers must pass `lo <= hi`.
    ///
    /// * If `node.pos` falls strictly inside a token, the whole subtree gets an empty
    ///   range anchored at that token's index, clamped into `[lo, hi]`.
    /// * If no token in `[lo, hi)` starts at or after `node.pos`, the whole subtree gets
    ///   an empty range at `lo`.
    /// * Otherwise the node covers `[first, hi)`, where `first` is the first token at or
    ///   after `node.pos`. Running to `hi` means a node also owns whatever tokens and
    ///   trivia precede the start of its next sibling. Each child is given a window from
    ///   its own start token to the next strictly later child start (or `hi`). That window
    ///   is clamped to `hi` so a child can never escape its parent.
    fn assign_ranges(&self, node: &mut TreeNode, lo: usize, hi: usize) {
        if self.inside_a_token(node.pos) {
            let anchor = self.containing_token(node.pos).unwrap_or(lo).clamp(lo, hi);
            Self::collapse(node, anchor);
            return;
        }

        let first = match self.token_at_or_after(node.pos) {
            Some(i) if (lo..hi).contains(&i) => i,
            _ => {
                Self::collapse(node, lo);
                return;
            }
        };

        // A child whose position precedes the parent's first token is pulled up to
        // `first`; a child with no token at or after it falls back to `first`.
        let child_starts: Vec<usize> = node
            .children
            .iter()
            .map(|c| self.token_at_or_after(c.pos).unwrap_or(first).max(first))
            .collect();

        for (idx, child) in node.children.iter_mut().enumerate() {
            let start = child_starts[idx];
            // Children sharing a start token all extend to the next *later* start.
            let next = child_starts[idx + 1..]
                .iter()
                .copied()
                .find(|&s| s > start)
                .unwrap_or(hi);
            // Clamp both bounds to the parent's window: a child (or sibling) start past
            // `hi` would otherwise yield a child range that escapes this node.
            let child_lo = start.min(hi);
            let child_hi = next.max(start + 1).min(hi);
            self.assign_ranges(child, child_lo, child_hi);
        }

        // `first < hi` holds here, so the node is non-empty and ends exactly at `hi`.
        node.range = TokenRange { first, end: hi };
    }

    /// Gives `node` and every descendant the empty range `[at, at)`.
    fn collapse(node: &mut TreeNode, at: usize) {
        node.range = TokenRange { first: at, end: at };
        for child in &mut node.children {
            Self::collapse(child, at);
        }
    }

    /// Index of the first token whose own span contains `pos` (`start <= pos < end`).
    /// Linear scan over all tokens.
    fn containing_token(&self, pos: usize) -> Option<usize> {
        self.tokens
            .iter()
            .position(|t| t.span.start <= pos && pos < t.span.end)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `src` round-trips through the tree and that every node range is
    /// well-formed (non-inverted and within the token stream).
    fn roundtrips(src: &str) {
        let tree = parse_with_tokens(src);
        assert_eq!(tree.unparse_lossless(), src, "tree round trip for {src:?}");
        let toks = tree.tokens.len();
        tree.root.walk(&mut |n| {
            assert!(
                n.range.first <= n.range.end,
                "inverted range on {}",
                n.label
            );
            assert!(n.range.end <= toks, "range past end on {}", n.label);
        });
    }

    #[test]
    fn root_roundtrips_across_corpus() {
        for src in [
            "Get-ChildItem -Path C:\\tmp | Where-Object { $_.Length -gt 1kb }\n",
            "$x = 1 + 2 * 3\n",
            "function Get-Thing {\n param([string]$Name)\n process { $Name }\n}\n",
            "if ($a -eq 1) { 'one' } elseif ($a -eq 2) { 'two' } else { 'other' }\n",
            "@{ a = 1; b = @(2, 3) } # a hashtable\n",
            "try { risky } catch [System.Exception] { recover } finally { cleanup }\n",
            "\"interpolated $x and $($y.Prop) end\"\n",
            "$obj.Method($arg)?.Chained[0]\n",
            "Get-WmiObject Win32_BIOS # trailing comment kept\n",
            " \n# just trivia\n ",
            "'unterminated string still round-trips",
            "",
        ] {
            roundtrips(src);
        }
    }

    #[test]
    fn node_unparse_recovers_exact_source_with_trivia() {
        let src = "$x = Get-WmiObject -Class Win32_BIOS # keep\n";
        let tree = parse_with_tokens(src);
        let mut command_text = None;
        tree.root.walk(&mut |n| {
            if n.label == "Command" {
                command_text = Some(n.unparse_lossless(&tree));
            }
        });
        let text = command_text.expect("a Command node");
        assert!(text.contains("Get-WmiObject"));
        assert!(text.contains("-Class"));
        assert!(text.contains("Win32_BIOS"));
        assert!(text.contains("Get-WmiObject -Class Win32_BIOS"));
    }

    #[test]
    fn token_ranges_nest_within_parents() {
        let src = "if ($x -gt 0) { Write-Output $x }\n";
        let tree = parse_with_tokens(src);
        fn check(node: &TreeNode) {
            for child in &node.children {
                if !child.range.is_empty() && !node.range.is_empty() {
                    assert!(
                        child.range.first >= node.range.first && child.range.end <= node.range.end,
                        "child {} [{},{}) escapes parent {} [{},{})",
                        child.label,
                        child.range.first,
                        child.range.end,
                        node.label,
                        node.range.first,
                        node.range.end,
                    );
                }
                check(child);
            }
        }
        check(&tree.root);
    }

    #[test]
    fn source_span_slices_back_to_node_text() {
        let src = "Write-Output 'hello'\n";
        let tree = parse_with_tokens(src);
        tree.root.walk(&mut |n| {
            if let Some(span) = n.source_span(&tree) {
                assert_eq!(span.slice(src), n.unparse_lossless(&tree));
            }
        });
    }

    #[test]
    fn errors_are_surfaced_from_v1() {
        let src = "}\n";
        let tree = parse_with_tokens(src);
        assert_eq!(tree.unparse_lossless(), src);
        // The tree must carry exactly the diagnostics the v1 parser reports for the same
        // input. Without this check the test would pass even if errors were dropped.
        let (_, expected) = crate::v1::parser::parse(src);
        assert_eq!(tree.errors, expected, "errors not surfaced from v1");
    }

    #[test]
    fn ranges_are_well_formed_under_fuzz() {
        /// Checks that each node's span slices back to its unparsed text and that
        /// non-empty child ranges nest inside non-empty parent ranges.
        fn nest_and_footprint(node: &TreeNode, tree: &Tree, src: &str) {
            if let Some(span) = node.source_span(tree) {
                assert!(span.start <= span.end && span.end <= src.len());
                assert_eq!(
                    span.slice(src),
                    node.unparse_lossless(tree),
                    "footprint mismatch on {}",
                    node.label
                );
            }
            for child in &node.children {
                if !child.range.is_empty() && !node.range.is_empty() {
                    assert!(
                        child.range.first >= node.range.first && child.range.end <= node.range.end,
                        "{} escapes parent {}",
                        child.label,
                        node.label
                    );
                }
                nest_and_footprint(child, tree, src);
            }
        }
        let charset: Vec<char> = "$@{}()[]| ;,.\"'`#-=+*/<>?:&\nabcXYZ012_".chars().collect();
        // Deterministic xorshift64 so failures are reproducible.
        let mut state: u64 = 0x1234_5678_9abc_def0;
        let mut next = || {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            state
        };
        for _ in 0..1000 {
            let len = (next() % 60) as usize;
            let body: String = (0..len)
                .map(|_| charset[(next() as usize) % charset.len()])
                .collect();
            let src = match next() % 3 {
                0 => format!("# header\n{body}"),
                1 => format!("\n\n {body}"),
                _ => body,
            };
            let tree = parse_with_tokens(&src);
            assert_eq!(tree.unparse_lossless(), src, "root round trip for {src:?}");
            nest_and_footprint(&tree.root, &tree, &src);
        }
    }

    #[test]
    fn top_level_statements_map_to_their_own_source() {
        let src = "$a = 1\nGet-WmiObject Win32_BIOS # legacy\n$b = 2\n";
        let tree = parse_with_tokens(src);
        let stmts: Vec<(&str, String)> = tree
            .root
            .children
            .iter()
            .map(|s| (s.label, s.unparse_lossless(&tree)))
            .collect();
        assert_eq!(
            stmts,
            vec![
                ("AssignmentStatement", "$a = 1\n".to_string()),
                (
                    "Command",
                    "Get-WmiObject Win32_BIOS # legacy\n".to_string()
                ),
                ("AssignmentStatement", "$b = 2\n".to_string()),
            ]
        );
    }

    #[test]
    fn walk_zipped_pairs_typed_nodes_with_ranges() {
        use crate::v1::ast::{AstNode, NodeInfo};
        let src = "Get-WmiObject -Class Win32_BIOS | Out-Null\n";
        let tree = parse_with_tokens(src);
        let mut commands = Vec::new();
        tree.walk_zipped(&mut |ast, node, _ancestors| {
            assert_eq!(ast.label(), node.label, "zip drifted");
            if let AstNode::Command(c) = ast {
                commands.push((c.name.clone(), node.range.first));
            }
        });
        assert_eq!(commands.len(), 2);
        assert_eq!(commands[0].0, "Get-WmiObject");
        assert_eq!(tree.tokens[commands[0].1].value, "Get-WmiObject");
        assert_eq!(commands[1].0, "Out-Null");
        assert_eq!(tree.tokens[commands[1].1].value, "Out-Null");
    }

    #[test]
    fn ancestors_expose_context() {
        let src = "function Outer { if ($x) { $inner } }\n$top\n";
        let tree = parse_with_tokens(src);
        let mut inner_path: Option<Vec<&'static str>> = None;
        let mut top_path: Option<Vec<&'static str>> = None;
        tree.root.walk_with_ancestors(&mut |node, ancestors| {
            if node.label == "Variable" {
                let path: Vec<&'static str> = ancestors.iter().map(|a| a.label).collect();
                if node.scalars.contains("inner") {
                    inner_path = Some(path);
                } else if node.scalars.contains("top") {
                    top_path = Some(path);
                }
            }
        });
        let inner_path = inner_path.expect("found $inner");
        assert!(inner_path.contains(&"FunctionDefinition"));
        assert!(inner_path.contains(&"IfStatement"));
        let top_path = top_path.expect("found $top");
        assert!(!top_path.contains(&"FunctionDefinition"));
        let mut saw = false;
        tree.walk_zipped(&mut |_, node, ancestors| {
            if node.label == "Variable" && node.scalars.contains("inner") {
                saw = ancestors
                    .iter()
                    .any(|(_, n)| n.label == "FunctionDefinition");
            }
        });
        assert!(saw);
    }
}