fp_growth/
tree.rs

//! `Tree` implements the tree data structure used by the FP-Growth algorithm.
2
3use std::{
4    cell::{Cell, RefCell},
5    collections::HashMap,
6    fmt::Debug,
7    rc::{Rc, Weak},
8    usize,
9};
10
11use crate::ItemType;
12
13type RcNode<T> = Rc<Node<T>>;
14type WeakRcNode<T> = Weak<Node<T>>;
15
16/// `Node<T>` represents the single node in a tree.
17#[derive(Debug)]
18pub struct Node<T> {
19    item: Option<T>,
20    count: Cell<usize>,
21    children: RefCell<Vec<RcNode<T>>>,
22    // Use Weak reference here to prevent the reference cycle.
23    parent: RefCell<WeakRcNode<T>>,
24    // The node's neighbor is the one with the same value that is "to the right"
25    // of it in the tree.
26    neighbor: RefCell<WeakRcNode<T>>,
27}
28
29impl<T: ItemType> PartialEq for Node<T> {
30    fn eq(&self, other: &Node<T>) -> bool {
31        self.item == other.item && self.parent.borrow().upgrade() == other.parent.borrow().upgrade()
32    }
33}
34
35impl<T: ItemType> Node<T> {
36    /// Create a new Node with the given item and count.
37    pub fn new(item: Option<T>, count: usize) -> Node<T> {
38        Node {
39            item,
40            count: Cell::new(count),
41            children: RefCell::new(vec![]),
42            parent: Default::default(),
43            neighbor: Default::default(),
44        }
45    }
46
47    /// Create a new Rc<Node> with the given item and count.
48    pub fn new_rc(item: Option<T>, count: usize) -> RcNode<T> {
49        Rc::new(Self::new(item, count))
50    }
51
52    /// Add the given child Node as a child of this node.
53    pub fn add_child(self: &Rc<Self>, child_node: RcNode<T>) {
54        let mut children = self.children.borrow_mut();
55        if !children.contains(&child_node) {
56            *child_node.parent.borrow_mut() = Rc::downgrade(self);
57            children.push(child_node);
58        }
59    }
60
61    pub fn remove_child(self: &Rc<Self>, child_node: RcNode<T>) {
62        let mut children = self.children.borrow_mut();
63        // for (index, node) in children.clone().into_iter().enumerate() {
64        //     if node == child_node {
65        //         children.remove(index);
66        //     }
67        // }
68        let index = children.iter().position(|x| *x == child_node).unwrap();
69        children.remove(index);
70    }
71
72    /// Check whether this node contains a child node for the given item.
73    /// If so, that node's reference is returned; otherwise, `None` is returned.
74    pub fn search(&self, item: T) -> Option<RcNode<T>> {
75        for node in self.children.borrow().iter() {
76            if let Some(child_node_item) = node.item {
77                if child_node_item == item {
78                    return Some(Rc::clone(node));
79                }
80            }
81        }
82        None
83    }
84
85    /// Increment the count associated with this node's item.
86    pub fn increment(&self, incr_count: usize) {
87        let old_count = self.count.get();
88        self.count.set(old_count + incr_count);
89    }
90
91    /// Print out the node.
92    pub fn print(&self, depth: usize) {
93        let padding = " ".repeat(depth);
94        let node_info;
95        match self.is_root() {
96            true => node_info = "<(root)>".to_string(),
97            false => node_info = format!("<{:?} {} (node)>", self.item, self.count.get()),
98        }
99        println!("{}{}", padding, node_info);
100        for child in self.children.borrow().iter() {
101            child.print(depth + 1);
102        }
103    }
104
105    pub fn item(&self) -> Option<T> {
106        self.item
107    }
108
109    /// Return the count value this node's item holds.
110    pub fn count(&self) -> usize {
111        self.count.get()
112    }
113
114    /// Return this node's neighbor node.
115    pub fn neighbor(&self) -> Option<RcNode<T>> {
116        self.neighbor.borrow().upgrade()
117    }
118
119    /// Return this node's parent node.
120    pub fn parent(&self) -> Option<RcNode<T>> {
121        self.parent.borrow().upgrade()
122    }
123
124    /// Check whether this node is a root node.
125    pub fn is_root(&self) -> bool {
126        self.item == None && self.count.get() == 0
127    }
128
129    /// Check whether this node is a leaf node.
130    pub fn is_leaf(&self) -> bool {
131        self.children.borrow().len() == 0
132    }
133}
134
135type Route<T> = (RefCell<RcNode<T>>, RefCell<RcNode<T>>);
136
137/// `Tree<T>` represents the main tree data struct will be used during the FP-Growth algorithm.
138pub struct Tree<T> {
139    root_node: RefCell<RcNode<T>>,
140    // routes is a HashMap who maintains a mapping which satisfies item -> (Head node, tail node).
141    routes: HashMap<T, Route<T>>,
142}
143
144impl<T: ItemType> Default for Tree<T> {
145    fn default() -> Self {
146        Self::new()
147    }
148}
149
150impl<T: ItemType> Tree<T> {
151    /// Create a new FP-Growth tree with an empty root node.
152    pub fn new() -> Tree<T> {
153        Tree {
154            root_node: RefCell::new(Node::new_rc(None, 0)),
155            routes: HashMap::new(),
156        }
157    }
158
159    /// Generate a partial tree with the given paths.
160    /// This function will be called during the algorithm.
161    pub fn generate_partial_tree(paths: &[Vec<RcNode<T>>]) -> Tree<T> {
162        let mut partial_tree = Tree::new();
163        let mut leaf_item = None;
164        for path in paths.iter() {
165            // Get leaf_count from the leaf node.
166            leaf_item = Some(path.last().unwrap().item.unwrap());
167            let mut cur_node = Rc::clone(&partial_tree.root_node.borrow());
168            for path_node in path.iter() {
169                match cur_node.search(path_node.item.unwrap()) {
170                    Some(child_node) => {
171                        cur_node = child_node;
172                    }
173                    None => {
174                        let next_node = Node::new_rc(path_node.item, {
175                            let mut count = 0;
176                            if path_node.item == leaf_item {
177                                count = path_node.count.get();
178                            }
179                            count
180                        });
181                        cur_node.add_child(Rc::clone(&next_node));
182                        partial_tree.update_route(Rc::clone(&next_node));
183                        cur_node = next_node;
184                    }
185                }
186            }
187        }
188
189        // Calculate the counts of the non-leaf nodes.
190        for path in partial_tree.generate_prefix_path(leaf_item.unwrap()).iter() {
191            let leaf_count = path.last().unwrap().count.get();
192            for path_node in path[..path.len() - 1].iter() {
193                path_node.increment(leaf_count);
194            }
195        }
196
197        partial_tree
198    }
199
200    /// Iterate the transaction and add every item to the FP-Growth tree.
201    pub fn add_transaction(&mut self, transaction: Vec<T>) {
202        let mut cur_node = Rc::clone(&self.root_node.borrow());
203        for &item in transaction.iter() {
204            match cur_node.search(item) {
205                // There is already a node in this tree for the current
206                // transaction item; reuse it.
207                Some(child_node) => {
208                    child_node.increment(1);
209                    cur_node = child_node;
210                }
211                None => {
212                    let next_node = Node::new_rc(Some(item), 1);
213                    cur_node.add_child(Rc::clone(&next_node));
214                    self.update_route(Rc::clone(&next_node));
215                    cur_node = next_node;
216                }
217            }
218        }
219    }
220
221    /// Update the route table that records the item and its node list.
222    pub fn update_route(&mut self, node: RcNode<T>) {
223        if let Some(item) = node.item {
224            match self.routes.get(&item) {
225                Some((_, tail)) => {
226                    let old_tail = tail.replace(Rc::clone(&node));
227                    *old_tail.neighbor.borrow_mut() = Rc::downgrade(&node);
228                }
229                None => {
230                    self.routes
231                        .insert(item, (RefCell::new(Rc::clone(&node)), RefCell::new(node)));
232                }
233            }
234        }
235    }
236
237    /// Generate the prefix paths that end with the given item.
238    pub fn generate_prefix_path(&self, item: T) -> Vec<Vec<RcNode<T>>> {
239        let mut cur_end_node = Rc::clone(&self.routes.get(&item).unwrap().0.borrow());
240        let mut paths = vec![];
241        loop {
242            let mut cur_node = Rc::clone(&cur_end_node);
243            let mut path = vec![Rc::clone(&cur_node)];
244            while let Some(parent_node) = cur_node.parent() {
245                if parent_node.is_root() {
246                    break;
247                }
248                path.push(Rc::clone(&parent_node));
249                cur_node = parent_node;
250            }
251            path.reverse();
252            paths.push(path);
253            match cur_end_node.neighbor() {
254                Some(neighbor_node) => cur_end_node = neighbor_node,
255                None => break,
256            }
257        }
258        paths
259    }
260
261    /// Get all nodes that holds the given item.
262    pub fn get_all_nodes(&self, item: T) -> Vec<RcNode<T>> {
263        match self.routes.get(&item) {
264            None => vec![],
265            Some((head_node, _)) => {
266                let mut nodes = vec![Rc::clone(&head_node.borrow())];
267                let mut cur_node = Rc::clone(&head_node.borrow());
268                while let Some(neighbor_node) = cur_node.neighbor() {
269                    nodes.push(Rc::clone(&neighbor_node));
270                    cur_node = neighbor_node;
271                }
272                nodes
273            }
274        }
275    }
276
277    /// Get all nodes with the given item.
278    pub fn get_all_items_nodes(&self) -> Vec<(T, Vec<RcNode<T>>)> {
279        let mut items_nodes = vec![];
280        for (item, _) in self.routes.iter() {
281            items_nodes.push((*item, self.get_all_nodes(*item)));
282        }
283        items_nodes
284    }
285
286    #[allow(dead_code)]
287    // [W.I.P] Prune the tree to reduce the search space.
288    fn prune(&self) {
289        let items_nodes = self.get_all_items_nodes();
290        for (item, nodes) in items_nodes.iter() {
291            if nodes.len() == 1 {
292                continue;
293            }
294            // Find all paths this item belongs to.
295            let mut all_paths = Vec::with_capacity(nodes.len());
296            let mut leaf_node_count = Vec::with_capacity(nodes.len());
297            for node in nodes.iter() {
298                if !node.is_leaf() {
299                    continue;
300                }
301                leaf_node_count.push(node.count());
302                let mut path = vec![];
303                let mut cur_node = Rc::clone(node);
304                while !cur_node.is_root() {
305                    path.push(Rc::clone(&cur_node));
306                    let parent_node = cur_node.parent().unwrap();
307                    cur_node = Rc::clone(&parent_node);
308                }
309                path.push(cur_node);
310                path.reverse();
311                all_paths.push(path);
312            }
313            if all_paths.len() < 2 {
314                continue;
315            }
316            // Find the common ancestor for all paths.
317            let mut common_ancestor_index = 0;
318            let mut common_ancestor = None;
319            for (index, node) in all_paths[0].iter().enumerate() {
320                let mut is_ancestor = true;
321                for path in all_paths.iter().skip(1) {
322                    let cur_node = Rc::clone(&path[index]);
323                    if cur_node != Rc::clone(node) {
324                        is_ancestor = false;
325                        break;
326                    }
327                }
328                if !is_ancestor {
329                    break;
330                }
331                common_ancestor_index = index;
332                common_ancestor = Some(node);
333            }
334            // Prune nodes which start from the common ancestor.
335            for (path_index, path) in all_paths.iter().enumerate() {
336                for node in path.iter().skip(common_ancestor_index + 1) {
337                    if node.count() <= leaf_node_count[path_index] {
338                        let parent_node = node.parent().unwrap();
339                        parent_node.remove_child(Rc::clone(node));
340                        break;
341                    }
342                }
343            }
344            common_ancestor
345                .unwrap()
346                .add_child(Node::new_rc(Some(*item), leaf_node_count.iter().sum()));
347        }
348    }
349
350    #[allow(dead_code)]
351    /// Print out the tree.
352    pub fn print(&self) {
353        println!("Tree:");
354        self.root_node.borrow().print(1);
355        println!("Routes:");
356        for (item, _) in self.routes.iter() {
357            println!("Item: {:?}", *item);
358            for node in self.get_all_nodes(*item).iter() {
359                println!("{:?}", Rc::into_raw(Rc::clone(node)));
360                println!("<{:?} {}>", node.item, node.count.get());
361            }
362        }
363    }
364}