// competitive_programming_rs/graph/re_rooting.rs
/// Re-rooting tree DP: after [`ReRooting::build`], holds for every vertex `v`
/// the DP value of the whole tree as if it were rooted at `v`.
///
/// The DP is described by three closures:
/// - `identity()` — neutral element of `merge`;
/// - `merge(a, b)` — combines the contributions of two neighbors, taken in
///   adjacency-list order (`a` comes first);
/// - `add_root(x)` — turns the merged contributions of a vertex's neighbors
///   into the value of the subtree hanging from that vertex.
///
/// `merge` must be associative with `identity()` as its neutral element (the
/// prefix/suffix products of the top-down pass rely on this). It does not have
/// to be commutative: every merge is performed in adjacency-list order.
///
/// Edges are stored exactly as inserted, so an undirected tree needs
/// [`ReRooting::add_edge`] called in both directions.
pub struct ReRooting<T, Identity, Merge, AddRoot> {
    /// `dp[v][i]`: contribution of the `i`-th neighbor of `v`, i.e. `add_root`
    /// applied to that neighbor's side of the tree.
    dp: Vec<Vec<T>>,
    /// `ans[v]`: DP value of the tree rooted at `v`.
    ans: Vec<T>,
    /// Adjacency lists; the order of `graph[v]` fixes the merge order at `v`.
    graph: Vec<Vec<usize>>,
    identity: Identity,
    merge: Merge,
    add_root: AddRoot,
}

impl<T, Identity, Merge, AddRoot> ReRooting<T, Identity, Merge, AddRoot>
where
    T: Clone,
    Identity: Fn() -> T,
    Merge: Fn(T, T) -> T,
    AddRoot: Fn(T) -> T,
{
    /// Creates a solver for `n` vertices (`0..n`) and no edges.
    ///
    /// Every answer starts out as `identity()`.
    pub fn new(n: usize, identity: Identity, merge: Merge, add_root: AddRoot) -> Self {
        Self {
            dp: vec![vec![]; n],
            ans: vec![identity(); n],
            graph: vec![vec![]; n],
            identity,
            merge,
            add_root,
        }
    }

    /// Appends `b` to the adjacency list of `a` (a directed edge `a -> b`).
    ///
    /// # Panics
    /// Panics if `a >= n`. An out-of-range `b` is not checked here; `build`
    /// panics if its traversal reaches such an edge.
    pub fn add_edge(&mut self, a: usize, b: usize) {
        self.graph[a].push(b);
    }

    /// Returns the DP value of the tree rooted at `v`.
    ///
    /// Before [`ReRooting::build`] has run — and, after it, for vertices not
    /// reachable from vertex 0 — this is the initial `identity()` value.
    ///
    /// # Panics
    /// Panics if `v >= n`.
    pub fn answer(&self, v: usize) -> &T {
        &self.ans[v]
    }

    /// Computes the answer of every vertex reachable from vertex 0.
    ///
    /// The answers are meaningful when the part of the graph reachable from
    /// vertex 0 is a tree whose edges were added in both directions. Cycles are
    /// not supported: the traversal does not terminate on one.
    ///
    /// Both passes are iterative, so deep (e.g. path-shaped) trees cannot
    /// overflow the call stack. Does nothing when `n == 0`.
    pub fn build(&mut self) {
        const ROOT: usize = 0;
        if self.graph.is_empty() {
            return;
        }
        let (order, parent) = self.traversal_order(ROOT);
        self.collect_subtrees(&order, &parent);
        self.propagate_from_root(&order, &parent);
    }

    /// Returns the vertices reachable from `root`, each listed after its
    /// parent, together with the parent of every visited vertex.
    ///
    /// The root's parent is recorded as `root` itself, so the only neighbor of
    /// the root that is treated as its parent is a self-loop.
    fn traversal_order(&self, root: usize) -> (Vec<usize>, Vec<usize>) {
        let n = self.graph.len();
        let mut parent = vec![root; n];
        let mut order = Vec::with_capacity(n);
        let mut stack = vec![root];
        while let Some(v) = stack.pop() {
            order.push(v);
            for &next in &self.graph[v] {
                if next != parent[v] {
                    parent[next] = v;
                    stack.push(next);
                }
            }
        }
        (order, parent)
    }

    /// Bottom-up pass: sets `dp[v][i]` to the subtree value of each child of
    /// `v` and leaves the parent slot at `identity()`.
    ///
    /// `order` lists every vertex after its parent; walking it in reverse
    /// finishes all children before their parent.
    fn collect_subtrees(&mut self, order: &[usize], parent: &[usize]) {
        // subtree[v] = add_root(merge of v's children in adjacency order): the
        // DP value of the subtree hanging from v in the tree rooted at order[0].
        let mut subtree = vec![(self.identity)(); self.graph.len()];
        for &v in order.iter().rev() {
            let p = parent[v];
            let mut row = vec![(self.identity)(); self.graph[v].len()];
            let mut sum = (self.identity)();
            for (i, &next) in self.graph[v].iter().enumerate() {
                if next == p {
                    continue;
                }
                let t = subtree[next].clone();
                row[i] = t.clone();
                sum = (self.merge)(sum, t);
            }
            self.dp[v] = row;
            subtree[v] = (self.add_root)(sum);
        }
    }

    /// Top-down pass: fills the parent slot of every vertex with the value of
    /// the rest of the tree, then stores `ans[v] = add_root(merge of all slots)`.
    ///
    /// `order` lists every vertex after its parent, so the value coming from
    /// the parent is known before the vertex itself is visited.
    fn propagate_from_root(&mut self, order: &[usize], parent: &[usize]) {
        // from_parent[v]: add_root(merge of the parent's slots except v's).
        // The root has no parent, so its entry stays identity().
        let mut from_parent = vec![(self.identity)(); self.graph.len()];
        for &v in order {
            let p = parent[v];
            let neighbors = &self.graph[v];
            let row = &mut self.dp[v];
            for (slot, &next) in row.iter_mut().zip(neighbors) {
                if next == p {
                    *slot = from_parent[v].clone();
                }
            }

            // left[i] = row[0] ⊕ … ⊕ row[i-1] and right[i] = row[i] ⊕ … ⊕ row[deg-1],
            // both in adjacency order, so left[i] ⊕ right[i+1] is every slot but i.
            let deg = neighbors.len();
            let mut left = vec![(self.identity)(); deg + 1];
            let mut right = vec![(self.identity)(); deg + 1];
            for i in 0..deg {
                left[i + 1] = (self.merge)(left[i].clone(), row[i].clone());
            }
            for i in (0..deg).rev() {
                // Element first, then the suffix: preserves adjacency order for
                // a non-commutative `merge`.
                right[i] = (self.merge)(row[i].clone(), right[i + 1].clone());
            }

            self.ans[v] = (self.add_root)(left[deg].clone());

            for (i, &next) in neighbors.iter().enumerate() {
                if next != p {
                    from_parent[next] =
                        (self.add_root)((self.merge)(left[i].clone(), right[i + 1].clone()));
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::graph::re_rooting::ReRooting;

    /// Reroots a (count, size) DP: `merge` multiplies the counts of two
    /// neighbor groups by the number of ways to interleave them, and
    /// `add_root` puts the root in front. For every root this counts the
    /// vertex orderings in which each vertex comes after its parent.
    #[test]
    fn test_re_rooting() {
        /// Binomial coefficient C(n, k), accumulated one factor at a time so
        /// every intermediate quotient is exact.
        fn comb(n: usize, k: usize) -> usize {
            (0..k).fold(1, |acc, i| acc * (n - i) / (i + 1))
        }

        let merge = |lhs: Option<(i64, usize)>, rhs: Option<(i64, usize)>| match (lhs, rhs) {
            (Some((ways_l, size_l)), Some((ways_r, size_r))) => {
                let interleavings = comb(size_l + size_r, size_l) as i64;
                Some((ways_l * ways_r * interleavings, size_l + size_r))
            }
            // At least one side is the identity (`None`): keep the other one.
            _ => lhs.or(rhs),
        };
        let add_root = |e: Option<(i64, usize)>| match e {
            Some((ways, size)) => Some((ways, size + 1)),
            None => Some((1, 1)),
        };

        let n = 8;
        let mut graph = ReRooting::new(n, || None, merge, add_root);
        // 1-based undirected edges, inserted in both directions.
        let edges = [(1, 2), (2, 3), (3, 4), (3, 5), (3, 6), (6, 7), (6, 8)];
        for &(u, v) in edges.iter() {
            let (u, v) = (u - 1, v - 1);
            graph.add_edge(u, v);
            graph.add_edge(v, u);
        }
        graph.build();

        let expected = [40, 280, 840, 120, 120, 504, 72, 72];
        for (v, &want) in expected.iter().enumerate() {
            assert_eq!(graph.ans[v].unwrap().0, want);
        }
    }
}