use super::{error::PoolError, node::Node, types::NodeId};
use im::HashMap;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Backing data of a [`NodePool`]: every node keyed by its id, plus a reverse index from
/// child id to parent id.
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct NodePoolInner {
    /// Id designated as the root (as supplied to `NodePool::from`; not validated there).
    pub root_id: NodeId,
    /// All nodes of the pool, keyed by the node's own `id`.
    pub nodes: im::HashMap<NodeId, Arc<Node>>,
    /// Maps a child id to the id of the node whose `content` lists it.
    pub parent_map: im::HashMap<NodeId, NodeId>,
}
impl NodePoolInner {
    /// Returns a new `NodePoolInner` in which node `id` has had `values` added to its
    /// `attrs`; `self` is not modified.
    ///
    /// Only the targeted `Node` value is cloned and re-wrapped in a fresh `Arc`; the
    /// result's `nodes` map is produced with `im::HashMap::update`, and `parent_map` /
    /// `root_id` are cloned from `self`.
    ///
    /// # Errors
    ///
    /// Returns [`PoolError::NodeNotFound`] if `id` has no entry in `nodes`.
    pub fn update_attr(
        &self,
        id: &NodeId,
        values: &HashMap<String, String>,
    ) -> Result<Self, PoolError> {
        // A single lookup serves as both the existence check and the fetch, so there is no
        // window in which an `unwrap` could fire.
        let node = self
            .nodes
            .get(id)
            .ok_or_else(|| PoolError::NodeNotFound(id.clone()))?;
        // Copy the `Node` out of the shared `Arc` so other holders of the old pool keep
        // seeing the unmodified node; cloning the `Arc` itself first would be wasted work.
        let mut updated_node = (**node).clone();
        // NOTE(review): for a map-typed `attrs`, `extend` presumably overwrites existing
        // keys with the new values — confirm against the `Extend` impl of `Node::attrs`.
        updated_node.attrs.extend(values.clone());
        let nodes = self.nodes.update(id.clone(), Arc::new(updated_node));
        Ok(NodePoolInner {
            nodes,
            parent_map: self.parent_map.clone(),
            root_id: self.root_id.clone(),
        })
    }
}
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct NodePool {
pub inner: Arc<NodePoolInner>,
}
// NOTE(review): `NodePool` only holds an `Arc<NodePoolInner>`, so the compiler already
// derives `Send`/`Sync` for it whenever everything reachable through that `Arc`
// (`NodeId`, `Node`, and the `im::HashMap`s holding them) is `Send + Sync`. If that holds,
// these manual impls are redundant; if it does not (e.g. `Node` contains an `Rc` or a
// `RefCell`), they are unsound. TODO: confirm and, if possible, delete them so the auto
// traits apply.
// SAFETY: unverified here — relies on `Node` and `NodeId` being safe to send and share
// across threads.
unsafe impl Send for NodePool {}
unsafe impl Sync for NodePool {}
impl NodePool {
    /// Returns the number of nodes stored in the pool.
    pub fn size(&self) -> usize {
        self.inner.nodes.len()
    }

    /// Builds a pool from a flat list of nodes rooted at `root_id`.
    ///
    /// Every node is stored under its own `id`, and every id in a node's `content` is
    /// recorded in `parent_map` as a child of that node. Nothing is validated here: when
    /// two nodes share an id, or one child id appears in several `content` lists, the entry
    /// inserted last replaces the earlier one, and `root_id` is not checked against the
    /// nodes. See [`NodePool::validate_hierarchy`] for a partial consistency check.
    pub fn from(
        nodes: Vec<Node>,
        root_id: NodeId,
    ) -> Self {
        let mut nodes_ref = HashMap::new();
        let mut parent_map_ref = HashMap::new();
        for node in nodes {
            for child_id in &node.content {
                parent_map_ref.insert(child_id.clone(), node.id.clone());
            }
            nodes_ref.insert(node.id.clone(), Arc::new(node));
        }
        NodePool {
            inner: Arc::new(NodePoolInner {
                nodes: nodes_ref,
                parent_map: parent_map_ref,
                root_id,
            }),
        }
    }

    /// Looks up the node with the given id.
    pub fn get_node(
        &self,
        id: &NodeId,
    ) -> Option<&Arc<Node>> {
        self.inner.nodes.get(id)
    }

    /// Returns `true` if a node with the given id is stored in the pool.
    pub fn contains_node(
        &self,
        id: &NodeId,
    ) -> bool {
        self.inner.nodes.contains_key(id)
    }

    /// Returns the child ids (`content`) of `parent_id`, or `None` if no such node exists.
    pub fn children(
        &self,
        parent_id: &NodeId,
    ) -> Option<&im::Vector<NodeId>> {
        self.get_node(parent_id).map(|n| &n.content)
    }

    /// Returns every node below `parent_id`, in depth-first pre-order (each child is
    /// followed immediately by its own descendants). `parent_id` itself is not included.
    ///
    /// Child ids that have no node in the pool are skipped, together with anything that
    /// would lie beneath them.
    pub fn descendants(
        &self,
        parent_id: &NodeId,
    ) -> Vec<&Node> {
        let mut result: Vec<&Node> = Vec::new();
        self._collect_descendants(parent_id, &mut result);
        result
    }

    /// Recursive helper for [`NodePool::descendants`]: appends each existing child of
    /// `parent_id`, followed by that child's own descendants, to `result`.
    ///
    /// NOTE(review): there is no cycle guard; a node reachable from its own `content`
    /// recurses until the stack overflows.
    fn _collect_descendants<'a>(
        &'a self,
        parent_id: &NodeId,
        result: &mut Vec<&'a Node>,
    ) {
        if let Some(children) = self.children(parent_id) {
            for child_id in children {
                if let Some(child) = self.get_node(child_id) {
                    result.push(child);
                    self._collect_descendants(child_id, result);
                }
            }
        }
    }

    /// Returns the id recorded as `child_id`'s parent, if any.
    ///
    /// The returned id is not guaranteed to have a node in the pool.
    pub fn parent_id(
        &self,
        child_id: &NodeId,
    ) -> Option<&NodeId> {
        self.inner.parent_map.get(child_id)
    }

    /// Iterates over the ids recorded above `id` in `parent_map`, nearest parent first.
    ///
    /// The walk is capped at `parent_map.len()` steps. An acyclic chain visits each
    /// `parent_map` key at most once, so the cap never truncates a valid chain. It does
    /// guarantee termination if the map contains a cycle, which neither `from` nor
    /// `Deserialize` rules out.
    fn parent_chain<'a>(
        &'a self,
        id: &'a NodeId,
    ) -> impl Iterator<Item = &'a NodeId> + 'a {
        let max_steps = self.inner.parent_map.len();
        std::iter::successors(self.parent_id(id), move |&current| self.parent_id(current))
            .take(max_steps)
    }

    /// Returns the ancestor nodes of `child_id`, nearest parent first.
    ///
    /// The walk stops at the first recorded parent id that has no node in the pool.
    pub fn ancestors(
        &self,
        child_id: &NodeId,
    ) -> Vec<&Arc<Node>> {
        let mut chain = Vec::new();
        for parent_id in self.parent_chain(child_id) {
            match self.get_node(parent_id) {
                Some(parent) => chain.push(parent),
                // A dangling parent id ends the chain.
                None => break,
            }
        }
        chain
    }

    /// Checks that every recorded parent link is backed by the parent node's `content`.
    ///
    /// Entries are checked in `parent_map` iteration order, and the first violation found is
    /// returned. Cycles and child ids that have no node of their own are not detected.
    ///
    /// # Errors
    ///
    /// - [`PoolError::OrphanNode`] (carrying the child id) if the recorded parent has no
    ///   node in the pool.
    /// - [`PoolError::InvalidParenting`] if the parent node exists but its `content` does
    ///   not list the child.
    pub fn validate_hierarchy(&self) -> Result<(), PoolError> {
        for (child_id, parent_id) in &self.inner.parent_map {
            if !self.contains_node(parent_id) {
                return Err(PoolError::OrphanNode(child_id.clone()));
            }
            if let Some(children) = self.children(parent_id) {
                if !children.contains(child_id) {
                    return Err(PoolError::InvalidParenting {
                        child: child_id.clone(),
                        alleged_parent: parent_id.clone(),
                    });
                }
            }
        }
        Ok(())
    }

    /// Returns every node for which `predicate` returns `true`, in the (unspecified)
    /// iteration order of the underlying map.
    pub fn filter_nodes<P>(
        &self,
        mut predicate: P,
    ) -> Vec<&Arc<Node>>
    where
        P: FnMut(&Node) -> bool,
    {
        self.inner.nodes.values().filter(|n| predicate(n)).collect()
    }

    /// Returns some node for which `predicate` returns `true`.
    ///
    /// The underlying map's iteration order is unspecified, so when several nodes match,
    /// which one is returned is not defined.
    pub fn find_node<P>(
        &self,
        mut predicate: P,
    ) -> Option<&Arc<Node>>
    where
        P: FnMut(&Node) -> bool,
    {
        self.inner.nodes.values().find(|n| predicate(n))
    }

    /// Returns the number of parent links above `node_id` (0 for a node with no recorded
    /// parent), or `None` if `node_id` is not in the pool.
    ///
    /// Parent ids that have no node of their own are still counted.
    pub fn get_node_depth(
        &self,
        node_id: &NodeId,
    ) -> Option<usize> {
        if !self.contains_node(node_id) {
            return None;
        }
        Some(self.parent_chain(node_id).count())
    }

    /// Returns the ids from the topmost recorded ancestor down to `node_id`, inclusive.
    ///
    /// An id with no recorded parent yields `vec![node_id]`, whether or not it is in the
    /// pool.
    pub fn get_node_path(
        &self,
        node_id: &NodeId,
    ) -> Vec<NodeId> {
        let mut path: Vec<NodeId> = std::iter::once(node_id)
            .chain(self.parent_chain(node_id))
            .cloned()
            .collect();
        // The chain is built bottom-up; callers expect root-first order.
        path.reverse();
        path
    }

    /// Returns `true` if `node_id` has no children. Ids not in the pool also count as
    /// leaves.
    pub fn is_leaf(
        &self,
        node_id: &NodeId,
    ) -> bool {
        if let Some(children) = self.children(node_id) {
            children.is_empty()
        } else {
            true
        }
    }

    /// Returns the other children of `node_id`'s parent, in `content` order, excluding
    /// every occurrence of `node_id`.
    ///
    /// Returns an empty `Vec` when no parent is recorded or the parent has no node.
    pub fn get_siblings(
        &self,
        node_id: &NodeId,
    ) -> Vec<NodeId> {
        if let Some(parent_id) = self.parent_id(node_id) {
            if let Some(children) = self.children(parent_id) {
                return children
                    .iter()
                    .filter(|&id| id != node_id)
                    .cloned()
                    .collect();
            }
        }
        Vec::new()
    }

    /// Returns all children of `node_id`'s parent in `content` order, including `node_id`
    /// itself.
    ///
    /// Returns an empty `Vec` when no parent is recorded or the parent has no node.
    pub fn get_all_siblings(
        &self,
        node_id: &NodeId,
    ) -> Vec<NodeId> {
        if let Some(parent_id) = self.parent_id(node_id) {
            if let Some(children) = self.children(parent_id) {
                return children.iter().cloned().collect();
            }
        }
        Vec::new()
    }

    /// Counts `node_id` plus every id listed beneath it through `content`.
    ///
    /// NOTE(review): unlike `descendants`, child ids with no node in the pool are still
    /// counted (as 1 each), and an unknown `node_id` yields 1. There is also no cycle guard:
    /// a cyclic `content` graph recurses until the stack overflows.
    pub fn get_subtree_size(
        &self,
        node_id: &NodeId,
    ) -> usize {
        let mut size = 1;
        if let Some(children) = self.children(node_id) {
            for child_id in children {
                size += self.get_subtree_size(child_id);
            }
        }
        size
    }

    /// Returns `true` if `ancestor_id` appears above `descendant_id` in the recorded
    /// parent chain. A node is not its own ancestor.
    pub fn is_ancestor(
        &self,
        ancestor_id: &NodeId,
        descendant_id: &NodeId,
    ) -> bool {
        self.parent_chain(descendant_id)
            .any(|parent_id| parent_id == ancestor_id)
    }

    /// Returns the deepest id on `node1_id`'s path that also lies on `node2_id`'s path
    /// (see [`NodePool::get_node_path`]).
    ///
    /// Both nodes count as part of their own paths, so `LCA(x, x) == x` and
    /// `LCA(parent, child) == parent`. Returns `None` when the paths share no id, for
    /// example for nodes in disjoint trees.
    pub fn get_lowest_common_ancestor(
        &self,
        node1_id: &NodeId,
        node2_id: &NodeId,
    ) -> Option<NodeId> {
        let path1 = self.get_node_path(node1_id);
        let path2 = self.get_node_path(node2_id);
        // Walk path1 from the node upwards so the first shared id is the deepest one.
        path1
            .iter()
            .rev()
            .find(|ancestor_id| path2.contains(ancestor_id))
            .cloned()
    }
}