use std::{ops::Index, sync::Arc, num::NonZeroUsize};
use std::hash::{Hash, Hasher};
use std::collections::hash_map::DefaultHasher;
use im::Vector;
use serde::{Deserialize, Serialize};
use serde_json::Value;

use crate::node_type::NodeEnum;
use crate::{
    error::PoolError,
    mark::Mark,
    node::Node,
    ops::{AttrsRef, MarkRef, NodeRef},
    types::NodeId,
};

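/// A sharded, persistent node pool: nodes are distributed across hash-based
/// shards in `nodes`, and `parent_map` records the parent id of every
/// non-root node.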
#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
pub struct Tree {
    pub root_id: NodeId,
    pub nodes: Vector<im::HashMap<NodeId, Arc<Node>>>,
    pub parent_map: im::HashMap<NodeId, NodeId>,
}

impl Tree {
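    /// Maps a node id to its shard by hashing the id and reducing it modulo
    /// the shard count fixed at construction time.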
    pub fn get_shard_index(
        &self,
        id: &NodeId,
    ) -> usize {
        let mut hasher = DefaultHasher::new();
        id.hash(&mut hasher);
        (hasher.finish() as usize) % self.nodes.len()
    }

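    /// Returns `true` if a node with the given id exists in the pool.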
    pub fn contains_node(
        &self,
        id: &NodeId,
    ) -> bool {
        let shard_index = self.get_shard_index(id);
        self.nodes[shard_index].contains_key(id)
    }

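    /// Looks up a node by id, returning a cheap `Arc` clone when present.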
    pub fn get_node(
        &self,
        id: &NodeId,
    ) -> Option<Arc<Node>> {
        let shard_index = self.get_shard_index(id);
        self.nodes[shard_index].get(id).cloned()
    }

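    /// Resolves the parent of `id` through `parent_map` and returns the
    /// parent node, if both the mapping and the node exist.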
    pub fn get_parent_node(
        &self,
        id: &NodeId,
    ) -> Option<Arc<Node>> {
        self.parent_map.get(id).and_then(|parent_id| {
            let shard_index = self.get_shard_index(parent_id);
            self.nodes[shard_index].get(parent_id).cloned()
        })
    }

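    /// Builds a tree from a root `NodeEnum` and its nested children,
    /// distributing every node across the shards and recording parent links.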
    pub fn from(nodes: NodeEnum) -> Self {
        let num_shards = std::cmp::max(
            std::thread::available_parallelism()
                .map(NonZeroUsize::get)
                .unwrap_or(2),
            2,
        );
        let mut shards = Vector::from(vec![im::HashMap::new(); num_shards]);
        let mut parent_map = im::HashMap::new();
        let (root_node, children) = nodes.into_parts();
        let root_id = root_node.id.clone();

        let mut hasher = DefaultHasher::new();
        root_id.hash(&mut hasher);
        let shard_index = (hasher.finish() as usize) % num_shards;
        shards[shard_index] =
            shards[shard_index].update(root_id.clone(), Arc::new(root_node));

        fn process_children(
            children: Vec<NodeEnum>,
            parent_id: &NodeId,
            shards: &mut Vector<im::HashMap<NodeId, Arc<Node>>>,
            parent_map: &mut im::HashMap<NodeId, NodeId>,
            num_shards: usize,
        ) {
            for child in children {
                let (node, grand_children) = child.into_parts();
                let node_id = node.id.clone();
                let mut hasher = DefaultHasher::new();
                node_id.hash(&mut hasher);
                let shard_index = (hasher.finish() as usize) % num_shards;
                shards[shard_index] =
                    shards[shard_index].update(node_id.clone(), Arc::new(node));
                parent_map.insert(node_id.clone(), parent_id.clone());

                process_children(
                    grand_children,
                    &node_id,
                    shards,
                    parent_map,
                    num_shards,
                );
            }
        }

        process_children(
            children,
            &root_id,
            &mut shards,
            &mut parent_map,
            num_shards,
        );

        Self { root_id, nodes: shards, parent_map }
    }

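    /// Creates a tree that contains only the given root node. The shard count
    /// is derived from the available parallelism, with a minimum of two.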
    pub fn new(root: Node) -> Self {
        let num_shards = std::cmp::max(
            std::thread::available_parallelism()
                .map(NonZeroUsize::get)
                .unwrap_or(2),
            2,
        );
        let mut nodes = Vector::from(vec![im::HashMap::new(); num_shards]);
        let root_id = root.id.clone();
        let mut hasher = DefaultHasher::new();
        root_id.hash(&mut hasher);
        let shard_index = (hasher.finish() as usize) % num_shards;
        nodes[shard_index] =
            nodes[shard_index].update(root_id.clone(), Arc::new(root));
        Self { root_id, nodes, parent_map: im::HashMap::new() }
    }

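    /// Merges `new_values` into the attributes of the node with the given id.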
    pub fn update_attr(
        &mut self,
        id: &NodeId,
        new_values: im::HashMap<String, Value>,
    ) -> Result<(), PoolError> {
        let shard_index = self.get_shard_index(id);
        let node = self.nodes[shard_index]
            .get(id)
            .ok_or(PoolError::NodeNotFound(id.clone()))?;
        let old_values = node.attrs.clone();
        let mut new_node = node.as_ref().clone();
        new_node.attrs = old_values.update(new_values);
        self.nodes[shard_index] =
            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
        Ok(())
    }

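    /// Inserts a subtree whose root must already exist in the pool: the
    /// existing entry is replaced by the subtree root (with the child ids
    /// appended to its content) and all descendants are added with parent
    /// links.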
    pub fn add(
        &mut self,
        nodes: NodeEnum,
    ) -> Result<(), PoolError> {
        let (mut node, children) = nodes.into_parts();
        let node_id = node.id.clone();

        // The subtree root must already be present in the pool.
        let shard_index = self.get_shard_index(&node_id);
        let _ = self.nodes[shard_index]
            .get(&node_id)
            .ok_or(PoolError::ParentNotFound(node_id.clone()))?;

        let child_ids: Vector<String> =
            children.iter().map(|n| n.0.id.clone()).collect();
        node.content.extend(child_ids);

        self.nodes[shard_index] =
            self.nodes[shard_index].update(node_id.clone(), Arc::new(node));

        let mut node_queue = Vec::new();
        node_queue.push((children, node_id.clone()));
        while let Some((current_children, parent_id)) = node_queue.pop() {
            for child in current_children {
                let (mut child_node, grand_children) = child.into_parts();
                let current_node_id = child_node.id.clone();

                let child_ids: Vector<String> =
                    grand_children.iter().map(|n| n.0.id.clone()).collect();
                child_node.content.extend(child_ids);

                let shard_index = self.get_shard_index(&current_node_id);
                self.nodes[shard_index] = self.nodes[shard_index]
                    .update(current_node_id.clone(), Arc::new(child_node));

                self.parent_map
                    .insert(current_node_id.clone(), parent_id.clone());

                node_queue.push((grand_children, current_node_id.clone()));
            }
        }
        Ok(())
    }

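    /// Appends `nodes[0]` as a child of `parent_id` and inserts every node in
    /// `nodes` into the pool, registering parent links for each node's
    /// `content` ids. Panics if `nodes` is empty.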
    pub fn add_node(
        &mut self,
        parent_id: &NodeId,
        nodes: &Vec<Node>,
    ) -> Result<(), PoolError> {
        let parent_shard_index = self.get_shard_index(parent_id);
        let parent = self.nodes[parent_shard_index]
            .get(parent_id)
            .ok_or(PoolError::ParentNotFound(parent_id.clone()))?;
        let mut new_parent = parent.as_ref().clone();
        new_parent.content.push_back(nodes[0].id.clone());
        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
            .update(parent_id.clone(), Arc::new(new_parent));
        self.parent_map.insert(nodes[0].id.clone(), parent_id.clone());
        for node in nodes {
            let shard_index = self.get_shard_index(&node.id);
            for child_id in &node.content {
                self.parent_map.insert(child_id.clone(), node.id.clone());
            }
            self.nodes[shard_index] = self.nodes[shard_index]
                .update(node.id.clone(), Arc::new(node.clone()));
        }
        Ok(())
    }

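    // Builder-style accessors: each returns a lightweight handle from the ops
    // module bound to the node addressed by `key`.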
    pub fn node(
        &mut self,
        key: &str,
    ) -> NodeRef<'_> {
        NodeRef::new(self, key.to_string())
    }

    pub fn mark(
        &mut self,
        key: &str,
    ) -> MarkRef<'_> {
        MarkRef::new(self, key.to_string())
    }

    pub fn attrs(
        &mut self,
        key: &str,
    ) -> AttrsRef<'_> {
        AttrsRef::new(self, key.to_string())
    }

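    /// Returns the child ids of `parent_id`, if that parent exists.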
    pub fn children(
        &self,
        parent_id: &NodeId,
    ) -> Option<im::Vector<NodeId>> {
        self.get_node(parent_id).map(|n| n.content.clone())
    }

    pub fn children_node(
        &self,
        parent_id: &NodeId,
    ) -> Option<im::Vector<Arc<Node>>> {
        self.children(parent_id)
            .map(|ids| ids.iter().filter_map(|id| self.get_node(id)).collect())
    }

    pub fn children_count(
        &self,
        parent_id: &NodeId,
    ) -> usize {
        self.get_node(parent_id).map(|n| n.content.len()).unwrap_or(0)
    }

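    /// Removes every mark equal to `mark` from the node with the given id.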
    pub fn remove_mark(
        &mut self,
        id: &NodeId,
        mark: Mark,
    ) -> Result<(), PoolError> {
        let shard_index = self.get_shard_index(id);
        let node = self.nodes[shard_index]
            .get(id)
            .ok_or(PoolError::NodeNotFound(id.clone()))?;
        let mut new_node = node.as_ref().clone();
        new_node.marks =
            new_node.marks.iter().filter(|&m| !m.eq(&mark)).cloned().collect();
        self.nodes[shard_index] =
            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
        Ok(())
    }

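    /// Appends the given marks to the node with the given id.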
    pub fn add_mark(
        &mut self,
        id: &NodeId,
        marks: &Vec<Mark>,
    ) -> Result<(), PoolError> {
        let shard_index = self.get_shard_index(id);
        let node = self.nodes[shard_index]
            .get(id)
            .ok_or(PoolError::NodeNotFound(id.clone()))?;
        let mut new_node = node.as_ref().clone();
        new_node.marks.extend(marks.clone());
        self.nodes[shard_index] =
            self.nodes[shard_index].update(id.clone(), Arc::new(new_node));
        Ok(())
    }

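    /// Validates that `nodes[0]` carries the id being replaced and then
    /// delegates to `add_node` to write the replacement nodes into the pool.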
    pub fn replace_node(
        &mut self,
        node_id: NodeId,
        nodes: &Vec<Node>,
    ) -> Result<(), PoolError> {
        let shard_index = self.get_shard_index(&node_id);
        let _ = self.nodes[shard_index]
            .get(&node_id)
            .ok_or(PoolError::NodeNotFound(node_id.clone()))?;
        if nodes[0].id != node_id {
            return Err(PoolError::InvalidNodeId {
                nodeid: node_id,
                new_node_id: nodes[0].id.clone(),
            });
        }
        self.add_node(&node_id, nodes)?;
        Ok(())
    }

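    /// Moves `node_id` from `source_parent_id` to `target_parent_id`,
    /// inserting it at `position` when that index is in bounds and appending
    /// it otherwise.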
    pub fn move_node(
        &mut self,
        source_parent_id: &NodeId,
        target_parent_id: &NodeId,
        node_id: &NodeId,
        position: Option<usize>,
    ) -> Result<(), PoolError> {
        let source_shard_index = self.get_shard_index(source_parent_id);
        let target_shard_index = self.get_shard_index(target_parent_id);
        let node_shard_index = self.get_shard_index(node_id);
        let source_parent = self.nodes[source_shard_index]
            .get(source_parent_id)
            .ok_or(PoolError::ParentNotFound(source_parent_id.clone()))?;
        let target_parent = self.nodes[target_shard_index]
            .get(target_parent_id)
            .ok_or(PoolError::ParentNotFound(target_parent_id.clone()))?;
        let _node = self.nodes[node_shard_index]
            .get(node_id)
            .ok_or(PoolError::NodeNotFound(node_id.clone()))?;
        if !source_parent.content.contains(node_id) {
            return Err(PoolError::InvalidParenting {
                child: node_id.clone(),
                alleged_parent: source_parent_id.clone(),
            });
        }
        let mut new_source_parent = source_parent.as_ref().clone();
        new_source_parent.content = new_source_parent
            .content
            .iter()
            .filter(|&id| id != node_id)
            .cloned()
            .collect();
        let mut new_target_parent = target_parent.as_ref().clone();
        match position {
            // Insert at the requested index when it is within bounds;
            // otherwise (or when no position is given) append at the end.
            Some(pos) if pos <= new_target_parent.content.len() => {
                new_target_parent.content.insert(pos, node_id.clone());
            }
            _ => new_target_parent.content.push_back(node_id.clone()),
        }
        self.nodes[source_shard_index] = self.nodes[source_shard_index]
            .update(source_parent_id.clone(), Arc::new(new_source_parent));
        self.nodes[target_shard_index] = self.nodes[target_shard_index]
            .update(target_parent_id.clone(), Arc::new(new_target_parent));
        self.parent_map.insert(node_id.clone(), target_parent_id.clone());
        Ok(())
    }

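    /// Detaches the given child ids from `parent_id` and removes each of
    /// their subtrees from the pool. The root node can never be removed.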
    pub fn remove_node(
        &mut self,
        parent_id: &NodeId,
        nodes: Vec<NodeId>,
    ) -> Result<(), PoolError> {
        let parent_shard_index = self.get_shard_index(parent_id);
        let parent = self.nodes[parent_shard_index]
            .get(parent_id)
            .ok_or(PoolError::ParentNotFound(parent_id.clone()))?;
        if nodes.contains(&self.root_id) {
            return Err(PoolError::CannotRemoveRoot);
        }
        for node_id in &nodes {
            if !parent.content.contains(node_id) {
                return Err(PoolError::InvalidParenting {
                    child: node_id.clone(),
                    alleged_parent: parent_id.clone(),
                });
            }
        }
        let nodes_to_remove: std::collections::HashSet<_> =
            nodes.iter().collect();
        let filtered_children: im::Vector<NodeId> = parent
            .as_ref()
            .content
            .iter()
            .filter(|&id| !nodes_to_remove.contains(id))
            .cloned()
            .collect();
        let mut parent_node = parent.as_ref().clone();
        parent_node.content = filtered_children;
        self.nodes[parent_shard_index] = self.nodes[parent_shard_index]
            .update(parent_id.clone(), Arc::new(parent_node));
        let mut remove_nodes = Vec::new();
        for node_id in nodes {
            self.remove_subtree(&node_id, &mut remove_nodes)?;
        }
        Ok(())
    }

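    /// Recursively removes `node_id` and all of its descendants, collecting
    /// the removed nodes into `remove_nodes`.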
    fn remove_subtree(
        &mut self,
        node_id: &NodeId,
        remove_nodes: &mut Vec<Node>,
    ) -> Result<(), PoolError> {
        if node_id == &self.root_id {
            return Err(PoolError::CannotRemoveRoot);
        }
        let shard_index = self.get_shard_index(node_id);
        let _ = self.nodes[shard_index]
            .get(node_id)
            .ok_or(PoolError::NodeNotFound(node_id.clone()))?;
        if let Some(children) = self.children(node_id) {
            for child_id in children {
                self.remove_subtree(&child_id, remove_nodes)?;
            }
        }
        self.parent_map.remove(node_id);
        if let Some(remove_node) = self.nodes[shard_index].remove(node_id) {
            remove_nodes.push(remove_node.as_ref().clone());
        }
        Ok(())
    }
}

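// Index implementations panic when the node is missing; use `get_node` for a
// fallible lookup.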
impl Index<&NodeId> for Tree {
    type Output = Arc<Node>;
    fn index(
        &self,
        index: &NodeId,
    ) -> &Self::Output {
        let shard_index = self.get_shard_index(index);
        self.nodes[shard_index].get(index).expect("Node not found")
    }
}

impl Index<&str> for Tree {
    type Output = Arc<Node>;
    fn index(
        &self,
        index: &str,
    ) -> &Self::Output {
        let node_id = NodeId::from(index);
        let shard_index = self.get_shard_index(&node_id);
        self.nodes[shard_index].get(&node_id).expect("Node not found")
    }
}