use crate::error::PoolResult;
use crate::{node_definition::NodeTree, tree::Tree};
use super::{error::error_helpers, node::Node, types::NodeId};
use serde::{Deserialize, Serialize};
use std::time::Instant;
use std::{sync::Arc};
use rayon::prelude::*;
use std::marker::Sync;
use std::sync::atomic::{AtomicUsize, Ordering};
use rpds::{VectorSync};
/// Process-wide counter read by the `NodePool` constructors to build each
/// pool's `pool_{n}` key.
static POOL_ID_COUNTER: AtomicUsize = AtomicUsize::new(0);
/// Boxed node predicate stored by `QueryEngine`.
///
/// The `Send + Sync` bounds let the conditions be evaluated from rayon worker
/// threads; the `'a` bound lets a condition borrow data (e.g. the pool) for
/// the engine's lifetime.
type NodeConditionRef<'a> = Box<dyn Fn(&Node) -> bool + Send + Sync + 'a>;
/// A keyed, shareable handle over a [`Tree`] of nodes.
///
/// The tree is held in an `Arc`, so cloning a `NodePool` shares the
/// underlying tree instead of copying it. The constructors assign `key` from
/// a global counter.
///
/// NOTE(review): `Deserialize` restores `key` from the serialized data rather
/// than drawing a fresh one from the counter, so a deserialized pool may
/// share its key with a live pool — confirm this is intended.
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct NodePool {
// Shared tree storage; the `Arc` keeps clones cheap.
inner: Arc<Tree>,
// Pool identifier of the form `pool_{n}`.
key: String,
}
impl NodePool {
    /// Wraps an existing [`Tree`] in a new pool and returns it behind an `Arc`.
    ///
    /// Every call draws the next value of the global pool counter, so each pool
    /// built this way receives a distinct `pool_{n}` key.
    #[cfg_attr(feature = "dev-tracing", tracing::instrument(skip(inner), fields(
        crate_name = "model",
        node_count = inner.nodes.iter().map(|i| i.values().len()).sum::<usize>()
    )))]
    pub fn new(inner: Arc<Tree>) -> Arc<NodePool> {
        Arc::new(Self { inner, key: Self::next_key() })
    }

    /// Returns the key assigned to this pool at construction.
    pub fn key(&self) -> &str {
        &self.key
    }

    /// Returns the total number of nodes across all shards of the tree.
    pub fn size(&self) -> usize {
        self.inner.nodes.iter().map(|i| i.values().len()).sum()
    }

    /// Returns the root node, or `None` if the root id has no node entry.
    pub fn root(&self) -> Option<&Node> {
        self.inner.get_node(&self.inner.root_id)
    }

    /// Returns the id of the tree's root node.
    pub fn root_id(&self) -> &NodeId {
        &self.inner.root_id
    }

    /// Returns the shared tree this pool wraps.
    pub fn get_inner(&self) -> &Arc<Tree> {
        &self.inner
    }

    /// Builds a new pool from a [`NodeTree`] definition.
    ///
    /// This is an inherent constructor, not an implementation of the `From`
    /// trait. Like [`NodePool::new`], it assigns a fresh `pool_{n}` key.
    pub fn from(nodes: NodeTree) -> Arc<NodePool> {
        // Reserve the key first, matching the original evaluation order.
        let key = Self::next_key();
        Arc::new(Self { inner: Arc::new(Tree::from(nodes)), key })
    }

    /// Reserves the next process-wide pool key (`pool_0`, `pool_1`, ...).
    ///
    /// Shared by both constructors so key generation lives in one place.
    fn next_key() -> String {
        let id = POOL_ID_COUNTER.fetch_add(1, Ordering::SeqCst);
        format!("pool_{id}")
    }

    /// Looks up a node by id.
    pub fn get_node(&self, id: &NodeId) -> Option<&Node> {
        self.inner.get_node(id)
    }

    /// Returns the parent node of `id`, as resolved by the underlying tree.
    pub fn get_parent_node(&self, id: &NodeId) -> Option<&Node> {
        self.inner.get_parent_node(id)
    }

    /// Returns `true` if the tree holds a node with this id.
    pub fn contains_node(&self, id: &NodeId) -> bool {
        self.inner.contains_node(id)
    }

    /// Returns a copy of the ordered child-id list of `parent_id`, or `None`
    /// if that node does not exist.
    pub fn children(&self, parent_id: &NodeId) -> Option<VectorSync<NodeId>> {
        self.get_node(parent_id).map(|n| n.content.clone())
    }

    /// Returns clones of every node below `parent_id`, in depth-first
    /// pre-order: each node precedes its own descendants, and siblings keep
    /// their child-list order.
    ///
    /// Child ids without a node entry are skipped, together with anything
    /// beneath them. `parent_id` itself is not included.
    pub fn descendants(&self, parent_id: &NodeId) -> Vec<Node> {
        let mut result: Vec<Node> = Vec::new();
        self._collect_descendants(parent_id, &mut result);
        result
    }

    /// Appends the descendants of `parent_id` to `result` (see [`NodePool::descendants`]).
    ///
    /// Uses an explicit work stack instead of recursion, so arbitrarily deep
    /// trees cannot overflow the call stack.
    fn _collect_descendants(&self, parent_id: &NodeId, result: &mut Vec<Node>) {
        let mut pending: Vec<&NodeId> = Vec::new();
        if let Some(parent) = self.get_node(parent_id) {
            Self::push_children_reversed(&mut pending, parent);
        }
        while let Some(child_id) = pending.pop() {
            if let Some(child) = self.get_node(child_id) {
                result.push(child.clone());
                Self::push_children_reversed(&mut pending, child);
            }
        }
    }

    /// Pushes `node`'s child ids onto `stack` in reverse order, so that
    /// popping the stack yields them in their original order.
    fn push_children_reversed<'n>(stack: &mut Vec<&'n NodeId>, node: &'n Node) {
        let first_new = stack.len();
        stack.extend(node.content.iter());
        stack[first_new..].reverse();
    }

    /// Calls `f` on each existing *direct* child of `id`, in child-list order.
    ///
    /// Child ids without a node entry are skipped. `f` may be `FnMut`, so it
    /// can accumulate state across calls.
    pub fn for_each<F>(&self, id: &NodeId, mut f: F)
    where
        F: FnMut(&Node),
    {
        if let Some(parent) = self.get_node(id) {
            for child in parent.content.iter().filter_map(|child_id| self.get_node(child_id)) {
                f(child);
            }
        }
    }

    /// Returns the parent id recorded for `child_id` in the tree's parent map.
    pub fn parent_id(&self, child_id: &NodeId) -> Option<&NodeId> {
        self.inner.parent_map.get(child_id)
    }

    /// Returns the ancestors of `child_id`, nearest parent first.
    ///
    /// The walk stops at the first parent id that has no node entry.
    pub fn ancestors(&self, child_id: &NodeId) -> Vec<&Node> {
        let mut chain = Vec::new();
        let mut current_id = child_id;
        while let Some(parent_id) = self.parent_id(current_id) {
            match self.get_node(parent_id) {
                Some(parent) => {
                    chain.push(parent);
                    current_id = parent_id;
                },
                None => break,
            }
        }
        chain
    }

    /// Checks that the parent map and the child lists agree.
    ///
    /// # Errors
    /// - `orphan_node(child)` if a parent-map entry names a parent id that is
    ///   not in the pool.
    /// - `invalid_parenting(child, parent)` if the parent exists but its child
    ///   list does not contain the child.
    pub fn validate_hierarchy(&self) -> PoolResult<()> {
        for (child_id, parent_id) in &self.inner.parent_map {
            if !self.contains_node(parent_id) {
                return Err(error_helpers::orphan_node(child_id.clone()));
            }
            if let Some(children) = self.children(parent_id) {
                if !children.iter().any(|id| id == child_id) {
                    return Err(error_helpers::invalid_parenting(
                        child_id.clone(),
                        parent_id.clone(),
                    ));
                }
            }
        }
        Ok(())
    }

    /// Returns every node in the pool that satisfies `predicate`.
    pub fn filter_nodes<P>(&self, predicate: P) -> Vec<&Node>
    where
        P: Fn(&Node) -> bool,
    {
        self.iter_nodes().filter(|n| predicate(n)).collect()
    }

    /// Returns the first node (in shard iteration order) that satisfies
    /// `predicate`, stopping as soon as one is found.
    pub fn find_node<P>(&self, predicate: P) -> Option<&Node>
    where
        P: Fn(&Node) -> bool,
    {
        self.iter_nodes().find(|n| predicate(n))
    }

    /// Returns the number of parent links between `node_id` and the top of
    /// its parent chain (so `0` for a node without a recorded parent).
    ///
    /// Returns `None` if `node_id` is not in the pool. Without this check an
    /// unknown id would be reported at depth 0, indistinguishable from the root.
    ///
    /// NOTE(review): assumes the parent map is acyclic; a cycle would loop forever.
    pub fn get_node_depth(&self, node_id: &NodeId) -> Option<usize> {
        if !self.contains_node(node_id) {
            return None;
        }
        let mut depth = 0;
        let mut current_id = node_id;
        while let Some(parent_id) = self.parent_id(current_id) {
            depth += 1;
            current_id = parent_id;
        }
        Some(depth)
    }

    /// Returns the ids from the topmost reachable ancestor down to `node_id`,
    /// inclusive. `node_id` is always the last element.
    pub fn get_node_path(&self, node_id: &NodeId) -> Vec<NodeId> {
        let mut path = Vec::new();
        let mut current_id = node_id;
        while let Some(parent_id) = self.parent_id(current_id) {
            path.push(current_id.clone());
            current_id = parent_id;
        }
        path.push(current_id.clone());
        path.reverse();
        path
    }

    /// Returns the existing nodes on the path from the top of `node_id`'s
    /// parent chain down to `node_id`.
    ///
    /// Ids without a node entry are skipped, but the upward walk continues
    /// past them.
    pub fn resolve(&self, node_id: &NodeId) -> Vec<&Node> {
        let mut result = Vec::new();
        let mut current_id = node_id;
        loop {
            if let Some(node) = self.get_node(current_id) {
                result.push(node);
            }
            match self.parent_id(current_id) {
                Some(parent_id) => current_id = parent_id,
                None => break,
            }
        }
        result.reverse();
        result
    }

    /// Returns `true` if `node_id` has no children, or does not exist.
    pub fn is_leaf(&self, node_id: &NodeId) -> bool {
        self.get_node(node_id).map_or(true, |node| node.content.is_empty())
    }

    /// Finds `node_id` within its parent's child list.
    ///
    /// Returns the parent's child ids and `node_id`'s index in them. Returns
    /// `None` when there is no parent, the parent has no node entry, or the
    /// parent's list does not contain `node_id`. The last case indicates an
    /// inconsistent tree, so it also prints a warning to stderr.
    fn locate_in_parent(&self, node_id: &NodeId) -> Option<(VectorSync<NodeId>, usize)> {
        let parent_id = self.parent_id(node_id)?;
        let siblings = self.children(parent_id)?;
        let position = siblings.iter().position(|id| id == node_id);
        match position {
            Some(index) => Some((siblings, index)),
            None => {
                eprintln!(
                    "Warning: Node {node_id:?} not found in parent's children list"
                );
                None
            },
        }
    }

    /// Returns the ids of the siblings that precede `node_id` in its parent's
    /// child list. Returns an empty list if the node cannot be located.
    pub fn get_left_siblings(&self, node_id: &NodeId) -> Vec<NodeId> {
        self.locate_in_parent(node_id)
            .map(|(siblings, index)| siblings.iter().take(index).cloned().collect())
            .unwrap_or_default()
    }

    /// Returns the ids of the siblings that follow `node_id` in its parent's
    /// child list. Returns an empty list if the node cannot be located.
    pub fn get_right_siblings(&self, node_id: &NodeId) -> Vec<NodeId> {
        self.locate_in_parent(node_id)
            .map(|(siblings, index)| siblings.iter().skip(index + 1).cloned().collect())
            .unwrap_or_default()
    }

    /// Returns the existing nodes among [`NodePool::get_left_siblings`].
    pub fn get_left_nodes(&self, node_id: &NodeId) -> Vec<&Node> {
        let siblings = self.get_left_siblings(node_id);
        siblings.iter().filter_map(|id| self.get_node(id)).collect()
    }

    /// Returns the existing nodes among [`NodePool::get_right_siblings`].
    pub fn get_right_nodes(&self, node_id: &NodeId) -> Vec<&Node> {
        let siblings = self.get_right_siblings(node_id);
        siblings.iter().filter_map(|id| self.get_node(id)).collect()
    }

    /// Returns the parent's full child list, *including* `node_id` itself.
    /// Returns an empty list if there is no parent or the parent does not exist.
    pub fn get_all_siblings(&self, node_id: &NodeId) -> Vec<NodeId> {
        if let Some(parent_id) = self.parent_id(node_id) {
            if let Some(children) = self.children(parent_id) {
                return children.iter().cloned().collect();
            }
        }
        Vec::new()
    }

    /// Counts `node_id` plus every id reachable through child lists.
    ///
    /// Each visited id counts once, including ids without a node entry (such
    /// ids are not expanded further). Uses an explicit stack, so deep trees
    /// cannot overflow the call stack.
    pub fn get_subtree_size(&self, node_id: &NodeId) -> usize {
        let mut size = 0;
        let mut pending = vec![node_id];
        while let Some(id) = pending.pop() {
            size += 1;
            if let Some(node) = self.get_node(id) {
                pending.extend(node.content.iter());
            }
        }
        size
    }

    /// Returns `true` if `ancestor_id` appears anywhere above `descendant_id`
    /// in the parent chain. A node is not its own ancestor.
    pub fn is_ancestor(&self, ancestor_id: &NodeId, descendant_id: &NodeId) -> bool {
        let mut current_id = descendant_id;
        while let Some(parent_id) = self.parent_id(current_id) {
            if parent_id == ancestor_id {
                return true;
            }
            current_id = parent_id;
        }
        false
    }

    /// Returns the deepest id shared by both nodes' paths. If one node is an
    /// ancestor of the other, that node is returned. Returns `None` if the
    /// paths share nothing.
    ///
    /// A node's path is fully determined by the parent map, so two paths that
    /// share any id agree on every id above it. The shared ids are therefore
    /// exactly the common prefix, which a single linear scan finds.
    pub fn get_lowest_common_ancestor(
        &self,
        node1_id: &NodeId,
        node2_id: &NodeId,
    ) -> Option<NodeId> {
        let path1 = self.get_node_path(node1_id);
        let path2 = self.get_node_path(node2_id);
        path1
            .iter()
            .zip(path2.iter())
            .take_while(|(a, b)| a == b)
            .last()
            .map(|(common, _)| common.clone())
    }

    /// Evaluates `predicate` against every node, processing shards in parallel
    /// with rayon, and returns the matching nodes.
    pub fn parallel_query<P>(&self, predicate: P) -> Vec<&Node>
    where
        P: Fn(&Node) -> bool + Send + Sync,
    {
        let shards: Vec<_> = self.inner.nodes.iter().collect();
        shards
            .into_par_iter()
            .flat_map(|shard| {
                shard.values().filter(|node| predicate(node)).collect::<Vec<_>>()
            })
            .collect()
    }

    /// Collects every node from every shard into a vector.
    fn get_all_nodes(&self) -> Vec<&Node> {
        self.iter_nodes().collect()
    }

    /// Lazily iterates over every node in every shard without allocating.
    fn iter_nodes(&self) -> impl Iterator<Item = &Node> + '_ {
        self.inner.nodes.iter().flat_map(|shard| shard.values())
    }
}
/// Builder-style query over a [`NodePool`].
///
/// The `by_*` methods each add one condition. Executing the query
/// (`find_all`, `find_first`, `count` and their parallel variants) keeps only
/// the nodes that satisfy *all* conditions.
pub struct QueryEngine<'a> {
// Pool being queried, borrowed for the engine's lifetime.
pool: &'a NodePool,
// Predicates a node must all satisfy to match.
conditions: Vec<NodeConditionRef<'a>>,
}
impl<'a> QueryEngine<'a> {
    /// Creates an engine over `pool` with no conditions, which matches every node.
    pub fn new(pool: &'a NodePool) -> Self {
        Self { pool, conditions: Vec::new() }
    }

    /// Keeps nodes whose type equals `node_type`.
    pub fn by_type(mut self, node_type: &'a str) -> Self {
        let node_type = node_type.to_string();
        self.conditions.push(Box::new(move |node| node.r#type == node_type));
        self
    }

    /// Keeps nodes whose attribute `key` is present and equal to `value`.
    pub fn by_attr(mut self, key: &'a str, value: &'a serde_json::Value) -> Self {
        let key = key.to_string();
        let value = value.clone();
        self.conditions
            .push(Box::new(move |node| node.attrs.get(&key) == Some(&value)));
        self
    }

    /// Keeps nodes that carry at least one mark of type `mark_type`.
    pub fn by_mark(mut self, mark_type: &'a str) -> Self {
        let mark_type = mark_type.to_string();
        self.conditions.push(Box::new(move |node| {
            node.marks.iter().any(|mark| mark.r#type == mark_type)
        }));
        self
    }

    /// Keeps nodes whose child list has exactly `count` entries.
    pub fn by_child_count(mut self, count: usize) -> Self {
        self.conditions.push(Box::new(move |node| node.content.len() == count));
        self
    }

    /// Keeps nodes whose [`NodePool::get_node_depth`] is exactly `depth`.
    pub fn by_depth(mut self, depth: usize) -> Self {
        // Copy the `&'a NodePool` reference. Cloning the pool itself would
        // needlessly duplicate its `Arc` and key string.
        let pool = self.pool;
        self.conditions.push(Box::new(move |node| {
            pool.get_node_depth(&node.id) == Some(depth)
        }));
        self
    }

    /// Keeps nodes that have an ancestor of type `ancestor_type`.
    pub fn by_ancestor_type(mut self, ancestor_type: &'a str) -> Self {
        let pool = self.pool;
        let ancestor_type = ancestor_type.to_string();
        self.conditions.push(Box::new(move |node| {
            pool.ancestors(&node.id)
                .iter()
                .any(|ancestor| ancestor.r#type == ancestor_type)
        }));
        self
    }

    /// Keeps nodes that have a descendant of type `descendant_type`.
    ///
    /// Descendants are resolved as in [`NodePool::descendants`]: child ids
    /// without a node entry are skipped along with anything below them.
    pub fn by_descendant_type(mut self, descendant_type: &'a str) -> Self {
        let pool = self.pool;
        let descendant_type = descendant_type.to_string();
        self.conditions.push(Box::new(move |node| {
            // Depth-first search that borrows instead of cloning the subtree
            // and stops at the first match.
            let mut pending: Vec<_> = match pool.get_node(&node.id) {
                Some(start) => start.content.iter().collect(),
                None => Vec::new(),
            };
            while let Some(id) = pending.pop() {
                if let Some(descendant) = pool.get_node(id) {
                    if descendant.r#type == descendant_type {
                        return true;
                    }
                    pending.extend(descendant.content.iter());
                }
            }
            false
        }));
        self
    }

    /// Returns `true` when `node` satisfies every registered condition.
    fn matches(&self, node: &Node) -> bool {
        self.conditions.iter().all(|condition| condition(node))
    }

    /// Returns every matching node.
    pub fn find_all(&self) -> Vec<&Node> {
        self.pool
            .get_all_nodes()
            .into_iter()
            .filter(|node| self.matches(node))
            .collect()
    }

    /// Returns the first matching node in shard iteration order.
    pub fn find_first(&self) -> Option<&Node> {
        self.pool.get_all_nodes().into_iter().find(|node| self.matches(node))
    }

    /// Returns the number of matching nodes.
    pub fn count(&self) -> usize {
        self.pool
            .get_all_nodes()
            .into_iter()
            .filter(|node| self.matches(node))
            .count()
    }

    /// Like [`QueryEngine::find_all`], but evaluates shards in parallel.
    pub fn parallel_find_all(&self) -> Vec<&Node> {
        self.pool.parallel_query(|node| self.matches(node))
    }

    /// Returns *some* matching node, found in parallel. Unlike
    /// [`QueryEngine::find_first`], which node is returned is not guaranteed.
    pub fn parallel_find_first(&self) -> Option<&Node> {
        self.pool
            .get_all_nodes()
            .into_par_iter()
            .find_any(|node| self.matches(node))
    }

    /// Like [`QueryEngine::count`], but evaluates the nodes in parallel.
    pub fn parallel_count(&self) -> usize {
        self.pool
            .get_all_nodes()
            .into_par_iter()
            .filter(|node| self.matches(node))
            .count()
    }
}
impl NodePool {
    /// Begins a builder-style query over this pool.
    ///
    /// The returned engine has no conditions yet, so running it unchanged
    /// matches every node.
    pub fn query(&self) -> QueryEngine<'_> {
        let pool: &NodePool = self;
        QueryEngine::new(pool)
    }
}
/// Settings for a query-result cache.
#[derive(Clone, Debug)]
pub struct QueryCacheConfig {
    /// Cache capacity (defaults to 1000).
    pub capacity: usize,
    /// Whether caching is switched on (defaults to `true`).
    pub enabled: bool,
}

impl Default for QueryCacheConfig {
    /// Capacity of 1000 with caching enabled.
    fn default() -> Self {
        let capacity = 1_000;
        QueryCacheConfig { enabled: true, capacity }
    }
}
/// Settings intended for a lazily indexed query engine.
///
/// NOTE(review): none of these fields are read by `LazyQueryEngine` in the
/// visible source; confirm where (or whether) they are consumed.
#[derive(Clone, Debug)]
pub struct LazyQueryConfig {
    /// Result-cache capacity (defaults to 1000).
    pub cache_capacity: usize,
    /// Index-cache capacity (defaults to 100).
    pub index_cache_capacity: usize,
    /// Whether caching is switched on (defaults to `true`).
    pub cache_enabled: bool,
    /// Presumably the number of queries before an index is built (defaults
    /// to 5) — TODO confirm semantics.
    pub index_build_threshold: usize,
}

impl Default for LazyQueryConfig {
    /// Capacities 1000/100, caching enabled, build threshold 5.
    fn default() -> Self {
        LazyQueryConfig {
            cache_enabled: true,
            cache_capacity: 1_000,
            index_cache_capacity: 100,
            index_build_threshold: 5,
        }
    }
}
/// Query helper over a [`NodePool`] whose lookups run a fresh parallel scan
/// on every call.
///
/// NOTE(review): despite the "lazy"/"index" naming, this type holds only the
/// pool reference — no results or indexes are cached between calls.
pub struct LazyQueryEngine<'a> {
// Pool being queried.
pool: &'a NodePool,
}
impl<'a> LazyQueryEngine<'a> {
    /// Creates an engine over `pool`.
    pub fn new(pool: &'a NodePool) -> Self {
        Self { pool }
    }

    /// Returns every node whose type equals `node_type`, and prints the elapsed
    /// time and match count to stdout.
    ///
    /// Takes `&self` so the engine can be queried repeatedly. The results
    /// borrow from the pool (`'a`), not from the engine.
    pub fn by_type_lazy(&self, node_type: &str) -> Vec<&'a Node> {
        let start = Instant::now();
        let nodes = self.build_type_index(node_type);
        let duration = start.elapsed();
        println!(
            "实时构建类型索引 '{}', 耗时: {:?}, 节点数: {}",
            node_type,
            duration,
            nodes.len()
        );
        nodes
    }

    /// Returns every node at exactly `depth` (see [`NodePool::get_node_depth`]),
    /// and prints the elapsed time and match count to stdout.
    pub fn by_depth_lazy(&self, depth: usize) -> Vec<&'a Node> {
        let start = Instant::now();
        let nodes = self.build_depth_index(depth);
        let duration = start.elapsed();
        println!(
            "实时构建深度索引 {}, 耗时: {:?}, 节点数: {}",
            depth,
            duration,
            nodes.len()
        );
        nodes
    }

    /// Returns every node carrying a mark of type `mark_type`, and prints the
    /// elapsed time and match count to stdout.
    pub fn by_mark_lazy(&self, mark_type: &str) -> Vec<&'a Node> {
        let start = Instant::now();
        let nodes = self.build_mark_index(mark_type);
        let duration = start.elapsed();
        println!(
            "实时构建标记索引 '{}', 耗时: {:?}, 节点数: {}",
            mark_type,
            duration,
            nodes.len()
        );
        nodes
    }

    /// Returns every node that satisfies all of `conditions`, evaluated in
    /// parallel. An empty slice matches every node.
    pub fn parallel_query(&self, conditions: &[QueryCondition]) -> Vec<&'a Node> {
        let pool: &'a NodePool = self.pool;
        pool.parallel_query(|node| conditions.iter().all(|cond| cond.matches(node)))
    }

    /// Scans the pool for nodes of type `node_type`.
    fn build_type_index(&self, node_type: &str) -> Vec<&'a Node> {
        let pool: &'a NodePool = self.pool;
        pool.parallel_query(|node| node.r#type == node_type)
    }

    /// Scans the pool for nodes whose depth equals `target_depth`.
    fn build_depth_index(&self, target_depth: usize) -> Vec<&'a Node> {
        let pool: &'a NodePool = self.pool;
        pool.parallel_query(|node| pool.get_node_depth(&node.id) == Some(target_depth))
    }

    /// Scans the pool for nodes carrying a mark of type `mark_type`.
    fn build_mark_index(&self, mark_type: &str) -> Vec<&'a Node> {
        let pool: &'a NodePool = self.pool;
        pool.parallel_query(|node| node.marks.iter().any(|mark| mark.r#type == mark_type))
    }
}
/// A data-only node predicate, evaluated by `QueryCondition::matches`.
#[derive(Debug, Clone)]
pub enum QueryCondition {
/// The node's type name equals the given string.
ByType(String),
/// The node carries at least one mark of the given type.
ByMark(String),
/// The node has attribute `key`, and its value equals `value`.
ByAttr { key: String, value: serde_json::Value },
/// The node's child list is empty.
IsLeaf,
/// The node's child list is non-empty.
HasChildren,
}
impl QueryCondition {
    /// Returns `true` when `node` satisfies this condition.
    pub fn matches(&self, node: &Node) -> bool {
        match self {
            Self::ByType(expected) => node.r#type == *expected,
            Self::ByMark(wanted) => node.marks.iter().any(|m| m.r#type == *wanted),
            Self::ByAttr { key, value } => node.attrs.get(key) == Some(value),
            Self::IsLeaf => node.content.is_empty(),
            Self::HasChildren => !node.content.is_empty(),
        }
    }

    /// Returns a string key that identifies this condition, e.g. `type_paragraph`.
    ///
    /// For `ByAttr`, the value is embedded as its JSON text. If serialization
    /// fails, that part of the key is left empty.
    pub fn cache_key(&self) -> String {
        match self {
            Self::ByType(name) => format!("type_{name}"),
            Self::ByMark(name) => format!("mark_{name}"),
            Self::ByAttr { key, value } => {
                let encoded = serde_json::to_string(value).unwrap_or_default();
                format!("attr_{key}_{encoded}")
            },
            Self::IsLeaf => String::from("is_leaf"),
            Self::HasChildren => String::from("has_children"),
        }
    }
}
/// Snapshot of per-cache sizes.
///
/// NOTE(review): nothing in the visible source constructs this type, and
/// despite its name, its fields look like entry counts rather than hit
/// rates — confirm intended use.
#[derive(Debug)]
pub struct CacheHitRates {
/// Presumably the number of entries in the query-result cache.
pub query_cache_size: usize,
/// Presumably the number of entries in the type-index cache.
pub type_index_cache_size: usize,
/// Presumably the number of entries in the depth-index cache.
pub depth_index_cache_size: usize,
/// Presumably the number of entries in the mark-index cache.
pub mark_index_cache_size: usize,
}
impl NodePool {
    /// Returns a [`LazyQueryEngine`] that borrows this pool.
    pub fn lazy_query(&self) -> LazyQueryEngine<'_> {
        let pool: &NodePool = self;
        LazyQueryEngine::new(pool)
    }
}