// competitive_programming_rs/graph/re_rooting.rs

/// Re-rooting DP on a tree: for every vertex, the value of the whole tree as if that vertex
/// were the root.
///
/// `T` is the DP value type. `Identity` produces the neutral element of `Merge`. `Merge`
/// combines the values of sibling subtrees. `AddRoot` turns the merged values hanging off a
/// vertex into the value of the subtree rooted at that vertex.
pub struct ReRooting<T, Identity, Merge, AddRoot> {
    /// `dp[v][i]`: value of the part of the tree on the far side of `v`'s `i`-th edge, as a
    /// subtree hanging off `v`.
    dp: Vec<Vec<T>>,
    /// `ans[v]`: result for the tree rooted at `v`.
    ans: Vec<T>,
    /// Adjacency lists: `graph[v][i]` is the `i`-th neighbour of `v`.
    graph: Vec<Vec<usize>>,
    /// Produces the neutral element of `merge`.
    identity: Identity,
    /// Combines two sibling-subtree values (expected to be associative).
    merge: Merge,
    /// Lifts a merged value into the value of a subtree that has a root on top.
    add_root: AddRoot,
}

10impl<T, Identity, Merge, AddRoot> ReRooting<T, Identity, Merge, AddRoot>
11where
12    T: Clone,
13    Identity: Fn() -> T,
14    Merge: Fn(T, T) -> T,
15    AddRoot: Fn(T) -> T,
16{
17    pub fn new(n: usize, identity: Identity, merge: Merge, add_root: AddRoot) -> Self {
18        Self {
19            dp: vec![vec![]; n],
20            ans: vec![identity(); n],
21            graph: vec![vec![]; n],
22            identity,
23            merge,
24            add_root,
25        }
26    }
27    pub fn add_edge(&mut self, a: usize, b: usize) {
28        self.graph[a].push(b);
29    }
30    pub fn build(&mut self) {
31        self.dfs(0, 0);
32        self.dfs2(0, 0, (self.identity)());
33    }
34
35    fn dfs(&mut self, v: usize, p: usize) -> T {
36        let mut sum = (self.identity)();
37        let deg = self.graph[v].len();
38        self.dp[v] = vec![(self.identity)(); deg];
39        let next = self.graph[v].clone();
40        for (i, next) in next.into_iter().enumerate() {
41            if next == p {
42                continue;
43            }
44            let t = self.dfs(next, v);
45            self.dp[v][i] = t.clone();
46            sum = (self.merge)(sum, t);
47        }
48        (self.add_root)(sum)
49    }
50    fn dfs2(&mut self, v: usize, p: usize, dp_p: T) {
51        for (i, &next) in self.graph[v].iter().enumerate() {
52            if next == p {
53                self.dp[v][i] = dp_p.clone();
54            }
55        }
56
57        let deg = self.graph[v].len();
58        let mut dp_l = vec![(self.identity)(); deg + 1];
59        let mut dp_r = vec![(self.identity)(); deg + 1];
60        for i in 0..deg {
61            dp_l[i + 1] = (self.merge)(dp_l[i].clone(), self.dp[v][i].clone());
62        }
63        for i in (0..deg).rev() {
64            dp_r[i] = (self.merge)(dp_r[i + 1].clone(), self.dp[v][i].clone());
65        }
66
67        self.ans[v] = (self.add_root)(dp_l[deg].clone());
68
69        let next = self.graph[v].clone();
70        for (i, next) in next.into_iter().enumerate() {
71            if next == p {
72                continue;
73            }
74            self.dfs2(
75                next,
76                v,
77                (self.add_root)((self.merge)(dp_l[i].clone(), dp_r[i + 1].clone())),
78            );
79        }
80    }
81}
82
#[cfg(test)]
mod tests {
    use crate::graph::re_rooting::ReRooting;

    #[test]
    fn test_re_rooting() {
        // DP value: `Some((ways, size))`. `ways` counts the orderings of a subtree's `size`
        // vertices in which every vertex comes after its parent. `None` is the empty forest,
        // i.e. the identity.
        fn binomial(n: usize, k: usize) -> usize {
            // After step `i` the accumulator equals C(n, i + 1), so the division is exact.
            (0..k).fold(1, |acc, i| acc * (n - i) / (i + 1))
        }

        // Interleave two independent forests: the orderings multiply, and the positions are
        // chosen in C(size_l + size_r, size_l) ways.
        let merge = |lhs: Option<(i64, usize)>, rhs: Option<(i64, usize)>| match (lhs, rhs) {
            (Some((ways_l, size_l)), Some((ways_r, size_r))) => {
                let interleavings = binomial(size_l + size_r, size_l) as i64;
                Some((ways_l * ways_r * interleavings, size_l + size_r))
            }
            _ => lhs.or(rhs),
        };
        // Putting a root on top must come first, so the count is unchanged and the size grows by one.
        let add_root = |e: Option<(i64, usize)>| match e {
            Some((ways, size)) => Some((ways, size + 1)),
            None => Some((1, 1)),
        };

        let edges = [(1, 2), (2, 3), (3, 4), (3, 5), (3, 6), (6, 7), (6, 8)];
        let mut graph = ReRooting::new(8, || None, merge, add_root);
        for &(a, b) in &edges {
            // The input is 1-indexed; the solver is 0-indexed and needs both directions.
            graph.add_edge(a - 1, b - 1);
            graph.add_edge(b - 1, a - 1);
        }

        graph.build();
        let expected = [40, 280, 840, 120, 120, 504, 72, 72];
        for (v, &want) in expected.iter().enumerate() {
            assert_eq!(graph.ans[v].unwrap().0, want, "vertex {}", v);
        }
    }
}