use std::fmt;
use rustc_hash::FxHashSet;
use smallvec::SmallVec;
use super::types::RuleId;
use crate::lattice::{EdgeId, NodeId};
/// Handle to a node stored in a `ParseForest`.
///
/// The wrapped `u32` is the node's position in the forest's node list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ForestNodeId(pub u32);

impl ForestNodeId {
    /// Wraps a raw node index as a forest node id.
    pub fn new(id: u32) -> Self {
        ForestNodeId(id)
    }

    /// Returns the raw node index wrapped by this id.
    pub fn id(&self) -> u32 {
        let ForestNodeId(raw) = *self;
        raw
    }
}

impl fmt::Display for ForestNodeId {
    /// Renders the id as `F<index>`, e.g. `F5`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "F{}", self.id())
    }
}
/// One node of the parse forest: a grammar rule recognised between two
/// lattice nodes, together with the ways it was derived.
#[derive(Debug, Clone)]
pub struct ForestNode {
    /// Rule this node was built from.
    pub rule: RuleId,
    /// Lattice node where the span starts (presumably inclusive — confirm against the parser).
    pub start: NodeId,
    /// Lattice node where the span ends.
    pub end: NodeId,
    /// Child entries in insertion order: terminal edges and/or derivations
    /// (sequences of sub-node ids).
    pub children: SmallVec<[ForestChild; 4]>,
}
impl ForestNode {
pub fn new(rule: RuleId, start: NodeId, end: NodeId) -> Self {
Self {
rule,
start,
end,
children: SmallVec::new(),
}
}
pub fn add_child(&mut self, child: ForestChild) {
self.children.push(child);
}
pub fn add_derivation(&mut self, children: SmallVec<[ForestNodeId; 4]>) {
self.children.push(ForestChild::Derivation(children));
}
pub fn add_terminal(&mut self, edge: EdgeId) {
self.children.push(ForestChild::Terminal(edge));
}
}
/// A single entry in a [`ForestNode`]'s child list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ForestChild {
    /// A derivation: an ordered sequence of sub-nodes of the forest.
    Derivation(SmallVec<[ForestNodeId; 4]>),
    /// A direct reference to a lattice edge.
    Terminal(EdgeId),
}
/// Arena of [`ForestNode`]s plus the set of node ids registered as roots.
#[derive(Debug, Clone, Default)]
pub struct ParseForest {
    /// All stored nodes; a node's `ForestNodeId` is its index in this vector.
    nodes: Vec<ForestNode>,
    /// Registered roots. Being a hash set, duplicates collapse and iteration
    /// order is determined by hashing, not by insertion order.
    roots: FxHashSet<ForestNodeId>,
}
impl ParseForest {
    /// Creates an empty forest (no nodes, no roots).
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `node` and returns the id it can be looked up by.
    ///
    /// Ids are dense and assigned in insertion order: the first node added
    /// gets `ForestNodeId(0)`, the next `ForestNodeId(1)`, and so on.
    ///
    /// # Panics
    ///
    /// Panics if the forest already holds more than `u32::MAX` nodes, because
    /// the next index would not fit in a `ForestNodeId`.
    pub fn add_node(&mut self, node: ForestNode) -> ForestNodeId {
        let index = self.nodes.len();
        // A bare `as u32` would wrap past u32::MAX and hand out an id that
        // aliases an existing node; fail loudly instead.
        assert!(
            index <= u32::MAX as usize,
            "ParseForest node count exceeds the range of ForestNodeId (u32)"
        );
        self.nodes.push(node);
        ForestNodeId::new(index as u32)
    }

    /// Registers `id` as a root. Registering the same id twice has no effect.
    pub fn add_root(&mut self, id: ForestNodeId) {
        self.roots.insert(id);
    }

    /// Looks up a node by id; `None` if no node with that id was added.
    pub fn node(&self, id: ForestNodeId) -> Option<&ForestNode> {
        self.nodes.get(id.0 as usize)
    }

    /// Mutable variant of [`Self::node`].
    pub fn node_mut(&mut self, id: ForestNodeId) -> Option<&mut ForestNode> {
        self.nodes.get_mut(id.0 as usize)
    }

    /// Iterates over the registered root ids.
    ///
    /// Order is that of the underlying `FxHashSet`, not insertion order.
    pub fn roots(&self) -> impl Iterator<Item = ForestNodeId> + '_ {
        self.roots.iter().copied()
    }

    /// Returns `true` when no root has been registered (nodes may still exist).
    pub fn is_empty(&self) -> bool {
        self.roots.is_empty()
    }

    /// Number of stored nodes.
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Number of distinct registered roots.
    pub fn num_roots(&self) -> usize {
        self.roots.len()
    }

    /// Extracts a tree from the first root yielded by [`Self::roots`].
    ///
    /// No ranking is performed: "best" is simply whichever root the root set
    /// yields first. Returns `None` if there are no roots or that root id has
    /// no stored node.
    pub fn best_parse(&self) -> Option<ParseTree> {
        self.roots().next().and_then(|root| self.extract_tree(root))
    }

    /// Extracts at most `limit` trees, one per root, in [`Self::roots`] order.
    ///
    /// Each root contributes a single tree (see [`Self::extract_tree`]).
    /// Alternative derivations inside a root are not enumerated. Roots whose
    /// id has no stored node are skipped.
    pub fn all_parses(&self, limit: usize) -> Vec<ParseTree> {
        self.roots()
            .filter_map(|root| self.extract_tree(root))
            .take(limit)
            .collect()
    }

    /// Builds one concrete [`ParseTree`] rooted at `root`.
    ///
    /// Children are walked in order. Terminals are copied until the first
    /// usable derivation; that derivation's sub-nodes are expanded
    /// recursively, and any children after it are ignored. A derivation is
    /// "usable" if none of its sub-nodes is an ancestor currently being
    /// expanded. This breaks cycles in the forest instead of recursing
    /// forever. Sub-node ids with no stored node are silently dropped.
    ///
    /// Returns `None` if `root` has no stored node.
    fn extract_tree(&self, root: ForestNodeId) -> Option<ParseTree> {
        let mut on_path = FxHashSet::default();
        self.extract_tree_guarded(root, &mut on_path)
    }

    /// Recursive worker for [`Self::extract_tree`].
    ///
    /// `on_path` holds the ids of the nodes currently being expanded (the
    /// ancestors of `id`). It is restored to its entry state before
    /// returning.
    fn extract_tree_guarded(
        &self,
        id: ForestNodeId,
        on_path: &mut FxHashSet<ForestNodeId>,
    ) -> Option<ParseTree> {
        let node = self.node(id)?;
        // Defensive: derivations are filtered below so an ancestor is never
        // re-entered, but refuse rather than loop if that ever changes.
        if !on_path.insert(id) {
            return None;
        }
        let mut tree = ParseTree {
            rule: node.rule,
            start: node.start,
            end: node.end,
            children: Vec::new(),
        };
        for child in &node.children {
            match child {
                ForestChild::Terminal(edge) => {
                    tree.children.push(ParseTreeChild::Terminal(*edge));
                }
                ForestChild::Derivation(kids) => {
                    // Skip alternatives that would re-enter an ancestor (a
                    // cycle); in an acyclic forest this never triggers, so
                    // the first derivation is always the one chosen.
                    if kids.iter().any(|kid| on_path.contains(kid)) {
                        continue;
                    }
                    for &kid in kids {
                        if let Some(kid_tree) = self.extract_tree_guarded(kid, on_path) {
                            tree.children.push(ParseTreeChild::Tree(Box::new(kid_tree)));
                        }
                    }
                    // Only one derivation is materialised per node.
                    break;
                }
            }
        }
        on_path.remove(&id);
        Some(tree)
    }

    /// Collects every lattice edge reachable from any root through any
    /// derivation (all alternatives, not just the ones `extract_tree` picks).
    ///
    /// Each forest node is visited at most once, so shared sub-nodes cost
    /// nothing extra and cycles terminate.
    pub fn collect_used_edges(&self) -> FxHashSet<EdgeId> {
        let mut edges = FxHashSet::default();
        let mut visited: FxHashSet<ForestNodeId> = FxHashSet::default();
        // Explicit stack: avoids deep recursion on large forests.
        let mut stack: Vec<ForestNodeId> = self.roots().collect();
        while let Some(id) = stack.pop() {
            if !visited.insert(id) {
                continue;
            }
            if let Some(node) = self.node(id) {
                for child in &node.children {
                    match child {
                        ForestChild::Derivation(kids) => {
                            stack.extend(kids.iter().copied().filter(|k| !visited.contains(k)));
                        }
                        ForestChild::Terminal(edge) => {
                            edges.insert(*edge);
                        }
                    }
                }
            }
        }
        edges
    }
}
/// A single concrete parse tree extracted from a `ParseForest`.
#[derive(Debug, Clone)]
pub struct ParseTree {
    /// Rule applied at this node.
    pub rule: RuleId,
    /// Lattice node where this subtree's span starts.
    pub start: NodeId,
    /// Lattice node where this subtree's span ends.
    pub end: NodeId,
    /// Ordered children: nested subtrees and terminal edges.
    pub children: Vec<ParseTreeChild>,
}
/// One child of a [`ParseTree`] node.
#[derive(Debug, Clone)]
pub enum ParseTreeChild {
    /// A nested subtree (boxed to keep the enum size bounded).
    Tree(Box<ParseTree>),
    /// A leaf that refers to a lattice edge.
    Terminal(EdgeId),
}
impl ParseTree {
    /// Height of the tree, counting this node as level 1.
    ///
    /// Terminal children add no extra level, so a node whose children are all
    /// terminals (or that has no children) has depth 1.
    pub fn depth(&self) -> usize {
        let mut deepest_child = 0;
        for child in &self.children {
            if let ParseTreeChild::Tree(subtree) = child {
                deepest_child = deepest_child.max(subtree.depth());
            }
        }
        deepest_child + 1
    }

    /// Total number of nodes: this node, every nested subtree node, and every
    /// terminal leaf.
    pub fn size(&self) -> usize {
        self.children.iter().fold(1, |total, child| {
            let contribution = match child {
                ParseTreeChild::Tree(subtree) => subtree.size(),
                ParseTreeChild::Terminal(_) => 1,
            };
            total + contribution
        })
    }

    /// Terminal edges in left-to-right (depth-first, pre-order) order.
    pub fn edges(&self) -> Vec<EdgeId> {
        let mut collected = Vec::new();
        self.collect_edges(&mut collected);
        collected
    }

    /// Appends this subtree's terminal edges to `out`, left to right.
    fn collect_edges(&self, out: &mut Vec<EdgeId>) {
        self.children.iter().for_each(|child| match child {
            ParseTreeChild::Tree(subtree) => subtree.collect_edges(out),
            ParseTreeChild::Terminal(edge) => out.push(*edge),
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a leaf forest node for `rule` with a single terminal `edge`.
    fn leaf(rule: u32, edge: u32) -> ForestNode {
        let mut node = ForestNode::new(RuleId::new(rule as _), NodeId(0), NodeId(1));
        node.add_terminal(EdgeId(edge as _));
        node
    }

    #[test]
    fn test_forest_node_id() {
        let id = ForestNodeId::new(5);
        assert_eq!(id.id(), 5);
        assert_eq!(format!("{}", id), "F5");
    }

    #[test]
    fn test_forest_node() {
        let mut node = ForestNode::new(RuleId::new(0), NodeId(0), NodeId(2));
        assert_eq!(node.rule, RuleId::new(0));
        assert!(node.children.is_empty());
        node.add_terminal(EdgeId(1));
        assert_eq!(node.children.len(), 1);
    }

    #[test]
    fn test_parse_forest_creation() {
        let mut forest = ParseForest::new();
        assert!(forest.is_empty());
        let node = ForestNode::new(RuleId::new(0), NodeId(0), NodeId(1));
        let id = forest.add_node(node);
        forest.add_root(id);
        assert!(!forest.is_empty());
        assert_eq!(forest.num_nodes(), 1);
        assert_eq!(forest.num_roots(), 1);
    }

    #[test]
    fn test_node_ids_are_dense_and_lookup_is_checked() {
        let mut forest = ParseForest::new();
        assert!(forest.node(ForestNodeId::new(0)).is_none());
        let first = forest.add_node(ForestNode::new(RuleId::new(0), NodeId(0), NodeId(1)));
        let second = forest.add_node(ForestNode::new(RuleId::new(0), NodeId(0), NodeId(1)));
        assert_eq!(first, ForestNodeId::new(0));
        assert_eq!(second, ForestNodeId::new(1));
        assert!(forest.node(second).is_some());
        assert!(forest.node(ForestNodeId::new(2)).is_none());
    }

    #[test]
    fn test_best_parse() {
        let mut forest = ParseForest::new();
        let mut root = ForestNode::new(RuleId::new(0), NodeId(0), NodeId(2));
        root.add_terminal(EdgeId(0));
        root.add_terminal(EdgeId(1));
        let root_id = forest.add_node(root);
        forest.add_root(root_id);
        let tree = forest.best_parse().expect("should have parse");
        assert_eq!(tree.rule, RuleId::new(0));
        assert_eq!(tree.children.len(), 2);
    }

    #[test]
    fn test_best_parse_empty_forest() {
        assert!(ParseForest::new().best_parse().is_none());
    }

    #[test]
    fn test_best_parse_uses_first_derivation() {
        let mut forest = ParseForest::new();
        let first_id = forest.add_node(leaf(1, 0));
        let second_id = forest.add_node(leaf(2, 1));
        let mut root = ForestNode::new(RuleId::new(0), NodeId(0), NodeId(1));
        root.add_derivation(SmallVec::from_slice(&[first_id]));
        root.add_derivation(SmallVec::from_slice(&[second_id]));
        let root_id = forest.add_node(root);
        forest.add_root(root_id);

        let tree = forest.best_parse().expect("should have parse");
        assert_eq!(tree.rule, RuleId::new(0));
        assert_eq!(tree.depth(), 2);
        assert_eq!(tree.size(), 3);
        assert_eq!(tree.edges(), vec![EdgeId(0)]);
    }

    #[test]
    fn test_parse_tree_metrics() {
        let tree = ParseTree {
            rule: RuleId::new(0),
            start: NodeId(0),
            end: NodeId(3),
            children: vec![
                ParseTreeChild::Tree(Box::new(ParseTree {
                    rule: RuleId::new(1),
                    start: NodeId(0),
                    end: NodeId(1),
                    children: vec![ParseTreeChild::Terminal(EdgeId(0))],
                })),
                ParseTreeChild::Terminal(EdgeId(1)),
            ],
        };
        assert_eq!(tree.depth(), 2);
        assert_eq!(tree.size(), 4);
        assert_eq!(tree.edges().len(), 2);
    }

    #[test]
    fn test_collect_used_edges() {
        let mut forest = ParseForest::new();
        let mut root = ForestNode::new(RuleId::new(0), NodeId(0), NodeId(2));
        root.add_terminal(EdgeId(0));
        root.add_terminal(EdgeId(1));
        let root_id = forest.add_node(root);
        forest.add_root(root_id);
        let edges = forest.collect_used_edges();
        assert_eq!(edges.len(), 2);
        assert!(edges.contains(&EdgeId(0)));
        assert!(edges.contains(&EdgeId(1)));
    }

    #[test]
    fn test_collect_used_edges_covers_all_derivations_and_shared_nodes() {
        let mut forest = ParseForest::new();
        let shared_id = forest.add_node(leaf(1, 7));
        let other_id = forest.add_node(leaf(2, 9));
        let mut root = ForestNode::new(RuleId::new(0), NodeId(0), NodeId(1));
        root.add_derivation(SmallVec::from_slice(&[shared_id, shared_id]));
        root.add_derivation(SmallVec::from_slice(&[other_id]));
        let root_id = forest.add_node(root);
        forest.add_root(root_id);

        let edges = forest.collect_used_edges();
        assert_eq!(edges.len(), 2);
        assert!(edges.contains(&EdgeId(7)));
        assert!(edges.contains(&EdgeId(9)));
    }

    #[test]
    fn test_all_parses() {
        let mut forest = ParseForest::new();
        let root1 = ForestNode::new(RuleId::new(0), NodeId(0), NodeId(1));
        let root2 = ForestNode::new(RuleId::new(1), NodeId(0), NodeId(1));
        let id1 = forest.add_node(root1);
        let id2 = forest.add_node(root2);
        forest.add_root(id1);
        forest.add_root(id2);
        let trees = forest.all_parses(10);
        assert_eq!(trees.len(), 2);
    }

    #[test]
    fn test_all_parses_respects_limit() {
        let mut forest = ParseForest::new();
        for rule in 0..3u32 {
            let id = forest.add_node(leaf(rule, rule));
            forest.add_root(id);
        }
        assert!(forest.all_parses(0).is_empty());
        assert_eq!(forest.all_parses(2).len(), 2);
        assert_eq!(forest.all_parses(10).len(), 3);
    }
}