use crate::error::{MinCutError, Result};
use std::collections::HashMap;
/// External identifier of a vertex.
///
/// `LinkCutTree` maps each id to a dense arena index; in general the id and the index are
/// unrelated numbers, so an id must never be used where an index is expected.
pub type NodeId = u64;

/// One vertex of the link-cut forest, stored in an arena (`Vec<SplayNode>`).
///
/// Every link field holds an arena index, not a [`NodeId`]. `parent` / `left` / `right`
/// describe the auxiliary splay tree that stores one preferred path, while `path_parent`
/// connects the root of such a splay tree to the splay tree of the path above it.
#[derive(Debug, Clone)]
struct SplayNode {
    /// External identifier of this vertex.
    id: NodeId,
    /// Parent in the auxiliary splay tree (`None` for a splay root).
    parent: Option<usize>,
    /// Left splay child: nodes that come before this one on the preferred path.
    left: Option<usize>,
    /// Right splay child: nodes that come after this one on the preferred path.
    right: Option<usize>,
    /// Path-parent pointer carried by splay roots.
    path_parent: Option<usize>,
    /// Number of nodes in this splay subtree.
    size: usize,
    /// Value stored on this vertex.
    value: f64,
    /// Minimum `value` over this splay subtree.
    path_aggregate: f64,
    /// Pending lazy reversal: children still have to be swapped.
    reversed: bool,
}

impl SplayNode {
    /// Creates a detached single-node splay tree whose aggregate is its own value.
    #[inline]
    fn new(id: NodeId, value: f64) -> Self {
        Self {
            id,
            parent: None,
            left: None,
            right: None,
            path_parent: None,
            size: 1,
            value,
            path_aggregate: value,
            reversed: false,
        }
    }

    /// Returns `true` if this node is the root of its auxiliary splay tree.
    ///
    /// A node is a splay root when it has no splay parent, or when that parent does not list
    /// it as a left or right child.
    ///
    /// `self` must be borrowed from `nodes` (call it as `nodes[i].is_root(&nodes)`). The node
    /// does not know its own arena index, and its [`NodeId`] is *not* an index, so child
    /// membership is decided by address identity. Calling this on a clone therefore reports
    /// `true`.
    #[inline(always)]
    fn is_root(&self, nodes: &[SplayNode]) -> bool {
        match self.parent {
            None => true,
            Some(p) => {
                let parent = &nodes[p];
                !parent
                    .left
                    .into_iter()
                    .chain(parent.right)
                    .any(|child| std::ptr::eq(&nodes[child], self))
            }
        }
    }
}
/// A forest of rooted trees supporting `link`, `cut`, root lookup, connectivity, lowest
/// common ancestor and path-minimum queries, implemented as a splay-based link-cut tree.
///
/// Vertices are addressed by caller-chosen [`NodeId`]s and stored in an arena; all internal
/// pointers are arena indices. `link(u, v)` makes `v` the parent of `u`.
pub struct LinkCutTree {
    /// Arena of splay nodes; a vertex's arena index is its internal handle.
    nodes: Vec<SplayNode>,
    /// External id -> arena index.
    id_to_index: HashMap<NodeId, usize>,
    /// Arena index -> external id, parallel to `nodes`.
    index_to_id: Vec<NodeId>,
    /// Memoized root lookups (vertex index -> root index). Cleared by every operation that
    /// changes the shape of the forest, so any entry present is current.
    root_cache: HashMap<usize, usize>,
}

impl LinkCutTree {
    /// Creates an empty forest.
    #[inline]
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            id_to_index: HashMap::new(),
            index_to_id: Vec::new(),
            root_cache: HashMap::new(),
        }
    }

    /// Creates an empty forest with storage preallocated for `n` vertices.
    #[inline]
    pub fn with_capacity(n: usize) -> Self {
        Self {
            nodes: Vec::with_capacity(n),
            id_to_index: HashMap::with_capacity(n),
            index_to_id: Vec::with_capacity(n),
            root_cache: HashMap::with_capacity(n / 4),
        }
    }

    /// Adds vertex `id` holding `value` as a new single-vertex tree and returns its arena
    /// index.
    ///
    /// NOTE: `id` is not checked for uniqueness. Registering an existing id again allocates a
    /// fresh node and re-points `id` at it. The old node stays in the arena (and in whatever
    /// tree it was linked into) but is no longer reachable by id, and it still counts
    /// towards [`LinkCutTree::len`].
    #[inline]
    pub fn make_tree(&mut self, id: NodeId, value: f64) -> usize {
        let index = self.nodes.len();
        self.nodes.push(SplayNode::new(id, value));
        self.id_to_index.insert(id, index);
        self.index_to_id.push(id);
        index
    }

    /// Resolves an external id to its arena index.
    #[inline]
    fn get_index(&self, id: NodeId) -> Result<usize> {
        self.id_to_index
            .get(&id)
            .copied()
            .ok_or_else(|| self.invalid_vertex_error(id))
    }

    #[cold]
    #[inline(never)]
    fn invalid_vertex_error(&self, id: NodeId) -> MinCutError {
        MinCutError::InvalidVertex(id)
    }

    /// Makes `v` the parent of `u`, merging their trees.
    ///
    /// `u` must be the root of its tree (use [`LinkCutTree::cut`] first if it has a parent).
    ///
    /// # Errors
    /// - [`MinCutError::InvalidVertex`] if either id is unknown.
    /// - [`MinCutError::InternalError`] if `u` and `v` are already in the same tree
    ///   (including `u == v`), or if `u` already has a parent.
    pub fn link(&mut self, u: NodeId, v: NodeId) -> Result<()> {
        let u_idx = self.get_index(u)?;
        let v_idx = self.get_index(v)?;
        if self.connected_indices(u_idx, v_idx) {
            return Err(self.already_connected_error());
        }
        // After `access(u)` the left splay subtree of `u` holds exactly its ancestors. If it
        // is occupied, `u` already has a parent; overwriting `left` would orphan that subtree
        // and corrupt the forest.
        self.access(u_idx);
        if self.nodes[u_idx].left.is_some() {
            return Err(self.not_root_error());
        }
        // `v` is in a different tree, so accessing it leaves `u` a splay root with no left
        // child. Hanging v's root path to the left of `u` places `v` and its ancestors before
        // `u` on the path, i.e. makes them u's ancestors.
        self.access(v_idx);
        self.nodes[u_idx].left = Some(v_idx);
        self.nodes[v_idx].parent = Some(u_idx);
        self.pull_up(u_idx);
        self.invalidate_root_cache();
        Ok(())
    }

    #[cold]
    #[inline(never)]
    fn already_connected_error(&self) -> MinCutError {
        MinCutError::InternalError("Nodes are already in the same tree".to_string())
    }

    #[cold]
    #[inline(never)]
    fn not_root_error(&self) -> MinCutError {
        MinCutError::InternalError("Node already has a parent; cut it before linking".to_string())
    }

    /// Detaches `v` from its parent, making `v` the root of its own tree.
    ///
    /// # Errors
    /// - [`MinCutError::InvalidVertex`] if `v` is unknown.
    /// - [`MinCutError::InternalError`] if `v` is already a root.
    pub fn cut(&mut self, v: NodeId) -> Result<()> {
        let v_idx = self.get_index(v)?;
        self.access(v_idx);
        // After `access`, the left splay subtree of `v` is exactly its ancestor path.
        let ancestors = self.nodes[v_idx].left.take();
        match ancestors {
            Some(left_idx) => {
                self.nodes[left_idx].parent = None;
                self.pull_up(v_idx);
                self.invalidate_root_cache();
                Ok(())
            }
            None => Err(self.already_root_error()),
        }
    }

    #[cold]
    #[inline(never)]
    fn already_root_error(&self) -> MinCutError {
        MinCutError::InternalError("Node is already a root".to_string())
    }

    /// Returns the id of the root of the tree containing `v`.
    ///
    /// # Errors
    /// [`MinCutError::InvalidVertex`] if `v` is unknown.
    #[inline]
    pub fn find_root(&mut self, v: NodeId) -> Result<NodeId> {
        let v_idx = self.get_index(v)?;
        let root_idx = self.root_index(v_idx);
        Ok(self.nodes[root_idx].id)
    }

    /// Returns the arena index of the root of the tree containing `v_idx`.
    ///
    /// Results are memoized in `root_cache`. Cached entries stay valid because every
    /// operation that can change a root (`link`, `cut`) clears the cache.
    fn root_index(&mut self, v_idx: usize) -> usize {
        if let Some(&cached) = self.root_cache.get(&v_idx) {
            return cached;
        }
        self.access(v_idx);
        // The tree root is the leftmost node of v's splay tree. Pending reversals are pushed
        // *before* reading `left`, so a lazily swapped subtree cannot misdirect the walk.
        let mut current = v_idx;
        loop {
            self.push_down(current);
            match self.nodes[current].left {
                Some(left) => current = left,
                None => break,
            }
        }
        // Splay the root to the top so the next walk is short.
        self.splay(current);
        self.root_cache.insert(v_idx, current);
        current
    }

    /// Drops every memoized root. A single `link`/`cut` can change the root of arbitrarily
    /// many vertices, so no finer-grained invalidation is attempted.
    #[inline]
    fn invalidate_root_cache(&mut self) {
        if !self.root_cache.is_empty() {
            self.root_cache.clear();
        }
    }

    /// Returns `true` if `u` and `v` are in the same tree. Unknown ids are never connected.
    #[inline]
    pub fn connected(&mut self, u: NodeId, v: NodeId) -> bool {
        match (self.get_index(u), self.get_index(v)) {
            (Ok(u_idx), Ok(v_idx)) => self.connected_indices(u_idx, v_idx),
            _ => false,
        }
    }

    /// Index-level connectivity: two vertices are connected iff they share a root.
    fn connected_indices(&mut self, u_idx: usize, v_idx: usize) -> bool {
        u_idx == v_idx || self.root_index(u_idx) == self.root_index(v_idx)
    }

    /// Returns the minimum vertex value on the path from the root of v's tree down to `v`
    /// (both ends included).
    ///
    /// # Errors
    /// [`MinCutError::InvalidVertex`] if `v` is unknown.
    #[inline]
    pub fn path_aggregate(&mut self, v: NodeId) -> Result<f64> {
        let v_idx = self.get_index(v)?;
        // `access` leaves v's splay tree holding exactly the root-to-v path.
        self.access(v_idx);
        Ok(self.nodes[v_idx].path_aggregate)
    }

    /// Replaces the value stored on `v`.
    ///
    /// # Errors
    /// [`MinCutError::InvalidVertex`] if `v` is unknown.
    #[inline]
    pub fn update_value(&mut self, v: NodeId, value: f64) -> Result<()> {
        let v_idx = self.get_index(v)?;
        self.set_value(v_idx, value);
        Ok(())
    }

    /// Writes `value` into node `idx` and keeps all cached aggregates consistent.
    ///
    /// The node is splayed first. Before that, every splay ancestor caches an aggregate that
    /// includes the old value, and editing in place would leave those stale. Once the node
    /// is the splay root, only its own aggregate needs recomputing.
    fn set_value(&mut self, idx: usize, value: f64) {
        self.splay(idx);
        self.nodes[idx].value = value;
        self.pull_up(idx);
    }

    /// Returns the lowest common ancestor of `u` and `v`.
    ///
    /// # Errors
    /// - [`MinCutError::InvalidVertex`] if either id is unknown.
    /// - [`MinCutError::InternalError`] if `u` and `v` are in different trees.
    pub fn lca(&mut self, u: NodeId, v: NodeId) -> Result<NodeId> {
        let u_idx = self.get_index(u)?;
        let v_idx = self.get_index(v)?;
        if !self.connected_indices(u_idx, v_idx) {
            return Err(self.not_connected_error());
        }
        // With root->u preferred, the point where accessing `v` joins that path is the LCA.
        self.access(u_idx);
        let lca_idx = self.access_with_lca(v_idx);
        Ok(self.nodes[lca_idx].id)
    }

    #[cold]
    #[inline(never)]
    fn not_connected_error(&self) -> MinCutError {
        MinCutError::InternalError("Nodes are not in the same tree".to_string())
    }

    /// Number of nodes ever created with [`LinkCutTree::make_tree`].
    #[inline]
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` if no node has been created.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Makes the root-to-`v` path preferred and splays `v` to the root of its splay tree.
    ///
    /// Afterwards `v` has no right splay child, its left subtree holds exactly its ancestors,
    /// and its `path_aggregate` covers the whole root-to-`v` path.
    #[inline]
    fn access(&mut self, v: usize) {
        self.access_with_lca(v);
    }

    /// Performs `access(v)` and returns the last node the walk reached through a
    /// `path_parent` pointer, or `v` if no such jump happened.
    ///
    /// When the root-to-`u` path was preferred just before this call, the returned node is
    /// the lowest common ancestor of `u` and `v`.
    #[inline]
    fn access_with_lca(&mut self, v: usize) -> usize {
        self.splay(v);
        // Detach the part of the path below `v`; it becomes its own path hanging off `v`.
        if let Some(right_idx) = self.nodes[v].right.take() {
            self.nodes[right_idx].path_parent = Some(v);
            self.nodes[right_idx].parent = None;
        }
        self.pull_up(v);
        let mut lca = v;
        let mut current = v;
        while let Some(pp) = self.nodes[current].path_parent {
            lca = pp;
            self.splay(pp);
            // Replace pp's preferred continuation with the path that leads to `v`.
            if let Some(old_right) = self.nodes[pp].right {
                self.nodes[old_right].path_parent = Some(pp);
                self.nodes[old_right].parent = None;
            }
            self.nodes[pp].right = Some(current);
            self.nodes[current].parent = Some(pp);
            self.nodes[current].path_parent = None;
            self.pull_up(pp);
            current = pp;
        }
        self.splay(v);
        lca
    }

    /// Splays `x` to the root of its auxiliary splay tree using zig / zig-zig / zig-zag steps.
    ///
    /// Pending reversals on `g`, `p` and `x` are pushed before each step, so the left/right
    /// comparisons see the true orientation.
    #[inline]
    fn splay(&mut self, x: usize) {
        while !self.nodes[x].is_root(&self.nodes) {
            let p = self.nodes[x]
                .parent
                .expect("a non-root splay node has a splay parent");
            if self.nodes[p].is_root(&self.nodes) {
                // Zig: `p` is the splay root, so one rotation finishes.
                self.push_down(p);
                self.push_down(x);
                self.rotate(x);
            } else {
                let g = self.nodes[p]
                    .parent
                    .expect("a non-root splay node has a splay parent");
                self.push_down(g);
                self.push_down(p);
                self.push_down(x);
                let x_is_left = self.nodes[p].left == Some(x);
                let p_is_left = self.nodes[g].left == Some(p);
                if x_is_left == p_is_left {
                    // Zig-zig: rotate the parent first.
                    self.rotate(p);
                    self.rotate(x);
                } else {
                    // Zig-zag.
                    self.rotate(x);
                    self.rotate(x);
                }
            }
        }
        self.push_down(x);
    }

    /// Rotates `x` above its splay parent, preserving in-order sequence.
    ///
    /// The parent's `path_parent` pointer moves to `x`, which takes its place as the subtree
    /// top, and the sizes/aggregates of both nodes are recomputed.
    #[inline]
    fn rotate(&mut self, x: usize) {
        let p = self.nodes[x].parent.expect("rotate requires a splay parent");
        let g = self.nodes[p].parent;
        let pp = self.nodes[p].path_parent;
        let x_is_left = self.nodes[p].left == Some(x);
        if x_is_left {
            let b = self.nodes[x].right;
            self.nodes[p].left = b;
            if let Some(b_idx) = b {
                self.nodes[b_idx].parent = Some(p);
            }
            self.nodes[x].right = Some(p);
        } else {
            let b = self.nodes[x].left;
            self.nodes[p].right = b;
            if let Some(b_idx) = b {
                self.nodes[b_idx].parent = Some(p);
            }
            self.nodes[x].left = Some(p);
        }
        self.nodes[p].parent = Some(x);
        self.nodes[x].parent = g;
        if let Some(g_idx) = g {
            if self.nodes[g_idx].left == Some(p) {
                self.nodes[g_idx].left = Some(x);
            } else if self.nodes[g_idx].right == Some(p) {
                self.nodes[g_idx].right = Some(x);
            }
        }
        self.nodes[x].path_parent = pp;
        self.nodes[p].path_parent = None;
        self.pull_up(p);
        self.pull_up(x);
    }

    /// Applies a pending reversal at `x`: swaps its children and forwards the flag to them.
    ///
    /// NOTE: nothing in this type currently sets `reversed`, so in practice this is a no-op.
    #[inline(always)]
    fn push_down(&mut self, x: usize) {
        if !self.nodes[x].reversed {
            return;
        }
        let left = self.nodes[x].left;
        let right = self.nodes[x].right;
        self.nodes[x].left = right;
        self.nodes[x].right = left;
        if let Some(left_idx) = left {
            self.nodes[left_idx].reversed ^= true;
        }
        if let Some(right_idx) = right {
            self.nodes[right_idx].reversed ^= true;
        }
        self.nodes[x].reversed = false;
    }

    /// Recomputes `size` and the minimum aggregate of `x` from its own value and its
    /// splay children.
    #[inline(always)]
    fn pull_up(&mut self, x: usize) {
        let mut size = 1;
        let mut aggregate = self.nodes[x].value;
        if let Some(left_idx) = self.nodes[x].left {
            size += self.nodes[left_idx].size;
            aggregate = aggregate.min(self.nodes[left_idx].path_aggregate);
        }
        if let Some(right_idx) = self.nodes[x].right {
            size += self.nodes[right_idx].size;
            aggregate = aggregate.min(self.nodes[right_idx].path_aggregate);
        }
        self.nodes[x].size = size;
        self.nodes[x].path_aggregate = aggregate;
    }

    /// Links every `(u, v)` pair in slice order, as if by [`LinkCutTree::link`].
    ///
    /// All ids are validated before any link is made. If a link then fails for a structural
    /// reason (already connected, or `u` not a root), its error is returned immediately and
    /// the links made earlier in the slice are kept.
    pub fn bulk_link(&mut self, edges: &[(NodeId, NodeId)]) -> Result<()> {
        for &(u, v) in edges {
            self.get_index(u)?;
            self.get_index(v)?;
        }
        // `link` invalidates the root cache itself.
        for &(u, v) in edges {
            self.link(u, v)?;
        }
        Ok(())
    }

    /// Sets the value of every listed vertex, in slice order (a repeated id keeps the last
    /// value given for it).
    ///
    /// # Errors
    /// [`MinCutError::InvalidVertex`] if any id is unknown; in that case no value is changed.
    pub fn bulk_update(&mut self, updates: &[(NodeId, f64)]) -> Result<()> {
        // Resolve every id up front so a bad id leaves the forest untouched.
        let mut indices = Vec::with_capacity(updates.len());
        for &(id, _) in updates {
            indices.push(self.get_index(id)?);
        }
        for (idx, &(_, value)) in indices.into_iter().zip(updates) {
            self.set_value(idx, value);
        }
        Ok(())
    }
}
impl Default for LinkCutTree {
fn default() -> Self {
Self::new()
}
}
// Unit tests for `LinkCutTree`. In every test `link(u, v)` makes `v` the parent of `u`,
// so a chain linked as 0 -> 1 -> 2 is rooted at 2.
#[cfg(test)]
mod tests {
use super::*;
// Fresh vertices get consecutive arena indices and keep their initial values.
#[test]
fn test_make_tree() {
let mut lct = LinkCutTree::new();
let idx0 = lct.make_tree(0, 1.0);
let idx1 = lct.make_tree(1, 2.0);
assert_eq!(idx0, 0);
assert_eq!(idx1, 1);
assert_eq!(lct.len(), 2);
assert_eq!(lct.nodes[idx0].value, 1.0);
assert_eq!(lct.nodes[idx1].value, 2.0);
}
// Chain 0 -> 1 -> 2: every vertex reports 2 as its root.
#[test]
fn test_link_and_find_root() {
let mut lct = LinkCutTree::new();
lct.make_tree(0, 1.0);
lct.make_tree(1, 2.0);
lct.make_tree(2, 3.0);
lct.link(0, 1).unwrap();
lct.link(1, 2).unwrap();
assert_eq!(lct.find_root(0).unwrap(), 2);
assert_eq!(lct.find_root(1).unwrap(), 2);
assert_eq!(lct.find_root(2).unwrap(), 2);
}
// Chain 0 -> 1 -> 2 plus an isolated vertex 3.
#[test]
fn test_connected() {
let mut lct = LinkCutTree::new();
lct.make_tree(0, 1.0);
lct.make_tree(1, 2.0);
lct.make_tree(2, 3.0);
lct.make_tree(3, 4.0);
lct.link(0, 1).unwrap();
lct.link(1, 2).unwrap();
assert!(lct.connected(0, 1));
assert!(lct.connected(0, 2));
assert!(lct.connected(1, 2));
assert!(!lct.connected(0, 3));
assert!(!lct.connected(2, 3));
}
// Cutting 1 from its parent 2 splits the chain into {0, 1} (root 1) and {2}.
#[test]
fn test_cut() {
let mut lct = LinkCutTree::new();
lct.make_tree(0, 1.0);
lct.make_tree(1, 2.0);
lct.make_tree(2, 3.0);
lct.link(0, 1).unwrap();
lct.link(1, 2).unwrap();
assert!(lct.connected(0, 2));
lct.cut(1).unwrap();
assert!(!lct.connected(0, 2));
assert!(lct.connected(0, 1));
assert_eq!(lct.find_root(0).unwrap(), 1);
assert_eq!(lct.find_root(2).unwrap(), 2);
}
// Chain 0 -> 1 -> 2 -> 3 with values 5, 3, 7, 2. The root 3 holds the minimum (2.0),
// so every root-to-vertex path minimum is 2.0.
#[test]
fn test_path_aggregate() {
let mut lct = LinkCutTree::new();
lct.make_tree(0, 5.0);
lct.make_tree(1, 3.0);
lct.make_tree(2, 7.0);
lct.make_tree(3, 2.0);
lct.link(0, 1).unwrap();
lct.link(1, 2).unwrap();
lct.link(2, 3).unwrap();
let agg = lct.path_aggregate(0).unwrap();
assert_eq!(agg, 2.0);
let agg = lct.path_aggregate(1).unwrap();
assert_eq!(agg, 2.0);
let agg = lct.path_aggregate(3).unwrap();
assert_eq!(agg, 2.0);
}
// Lowering a value must be reflected in the path minimum.
#[test]
fn test_update_value() {
let mut lct = LinkCutTree::new();
lct.make_tree(0, 5.0);
lct.make_tree(1, 3.0);
lct.link(0, 1).unwrap();
lct.update_value(0, 1.0).unwrap();
let agg = lct.path_aggregate(0).unwrap();
assert_eq!(agg, 1.0);
}
// Tree rooted at 4 with children 3 and 2; below 3 hangs the chain 1, then 0.
#[test]
fn test_lca() {
let mut lct = LinkCutTree::new();
for i in 0..5 {
lct.make_tree(i, i as f64);
}
lct.link(0, 1).unwrap();
lct.link(1, 3).unwrap();
lct.link(2, 4).unwrap();
lct.link(3, 4).unwrap();
assert!(lct.connected(0, 1), "0 and 1 should be connected");
assert!(lct.connected(0, 3), "0 and 3 should be connected");
assert!(lct.connected(0, 4), "0 and 4 should be connected");
assert!(lct.connected(2, 4), "2 and 4 should be connected");
let lca = lct.lca(0, 2).unwrap();
assert_eq!(lca, 4);
let lca = lct.lca(0, 1).unwrap();
assert_eq!(lca, 1);
let lca = lct.lca(0, 3).unwrap();
assert_eq!(lca, 3);
}
// Chains 0 -> ... -> 4 and 5 -> 6 -> 7. cut(2) separates {0, 1, 2} from {3, 4};
// link(4, 7) then attaches {3, 4} below 7.
#[test]
fn test_complex_operations() {
let mut lct = LinkCutTree::with_capacity(10);
for i in 0..10 {
lct.make_tree(i, i as f64 * 0.5);
}
for i in 0..4 {
lct.link(i, i + 1).unwrap();
}
lct.link(5, 6).unwrap();
lct.link(6, 7).unwrap();
assert!(lct.connected(0, 4));
assert!(lct.connected(5, 7));
assert!(!lct.connected(0, 5));
lct.cut(2).unwrap();
assert!(!lct.connected(0, 4));
assert!(lct.connected(0, 2));
lct.link(4, 7).unwrap();
assert!(
lct.connected(4, 7),
"4 and 7 should be connected after link"
);
assert!(lct.connected(3, 7), "3 and 7 should be connected through 4");
}
// Re-linking, cutting a root, and unknown ids are all rejected with errors.
#[test]
fn test_error_cases() {
let mut lct = LinkCutTree::new();
lct.make_tree(0, 1.0);
lct.make_tree(1, 2.0);
lct.link(0, 1).unwrap();
assert!(lct.link(0, 1).is_err());
assert!(lct.cut(1).is_err());
assert!(lct.find_root(99).is_err());
assert!(lct.link(0, 99).is_err());
}
// 1000-vertex chain rooted at 999. The path from 0 covers every vertex, so its minimum
// is 0.0. cut(500) splits the chain into {0..=500} (root 500) and {501..=999} (root 999).
#[test]
fn test_large_tree() {
let mut lct = LinkCutTree::with_capacity(1000);
for i in 0..1000 {
lct.make_tree(i, i as f64);
}
for i in 0..999 {
lct.link(i, i + 1).unwrap();
}
assert_eq!(lct.find_root(0).unwrap(), 999);
assert_eq!(lct.find_root(500).unwrap(), 999);
let agg = lct.path_aggregate(0).unwrap();
assert_eq!(agg, 0.0);
lct.cut(500).unwrap();
assert_eq!(lct.find_root(0).unwrap(), 500);
assert_eq!(lct.find_root(999).unwrap(), 999);
}
// Three disjoint chains rooted at 2, 5 and 8; link(2, 5) merges the first two under 5.
#[test]
fn test_multiple_forests() {
let mut lct = LinkCutTree::new();
for i in 0..9 {
lct.make_tree(i, i as f64);
}
lct.link(0, 1).unwrap();
lct.link(1, 2).unwrap();
lct.link(3, 4).unwrap();
lct.link(4, 5).unwrap();
lct.link(6, 7).unwrap();
lct.link(7, 8).unwrap();
assert_eq!(lct.find_root(0).unwrap(), 2);
assert_eq!(lct.find_root(3).unwrap(), 5);
assert_eq!(lct.find_root(6).unwrap(), 8);
assert!(!lct.connected(0, 3));
assert!(!lct.connected(3, 6));
assert!(!lct.connected(0, 6));
lct.link(2, 5).unwrap();
assert!(lct.connected(0, 5));
assert_eq!(lct.find_root(0).unwrap(), 5);
assert_eq!(lct.find_root(3).unwrap(), 5);
}
// bulk_link builds the chain 0 -> 1 -> 2 -> 3; bulk_update overwrites the stored values.
#[test]
fn test_bulk_operations() {
let mut lct = LinkCutTree::with_capacity(10);
for i in 0..10 {
lct.make_tree(i, i as f64);
}
let edges = vec![(0, 1), (1, 2), (2, 3)];
lct.bulk_link(&edges).unwrap();
assert!(lct.connected(0, 3));
let updates = vec![(0, 10.0), (1, 20.0), (2, 30.0)];
lct.bulk_update(&updates).unwrap();
assert_eq!(lct.nodes[0].value, 10.0);
assert_eq!(lct.nodes[1].value, 20.0);
assert_eq!(lct.nodes[2].value, 30.0);
}
// Repeated find_root calls may be answered from the root cache; after cut(50) the answer
// for vertex 0 must reflect its new root, 50.
#[test]
fn test_root_caching() {
let mut lct = LinkCutTree::with_capacity(100);
for i in 0..100 {
lct.make_tree(i, i as f64);
}
for i in 0..99 {
lct.link(i, i + 1).unwrap();
}
let root1 = lct.find_root(0).unwrap();
assert_eq!(root1, 99);
let root2 = lct.find_root(0).unwrap();
assert_eq!(root2, 99);
lct.cut(50).unwrap();
let root3 = lct.find_root(0).unwrap();
assert_eq!(root3, 50);
}
}