1use std::{
65 cell::Cell,
66 ops::{Index, IndexMut},
67};
68
69use bitvec::vec::BitVec;
70use thiserror::Error;
71
72use crate::half_edge::{
73 involution::Hedge,
74 subgraph::{Inclusion, SubGraph, SubGraphOps},
75};
76
/// Position of a node in `UnionFind::nodes`; `UFNode::Child` stores one of
/// these to refer to its parent node.
pub type ParentPointer = Hedge;
80
/// Position of a set's record in `UnionFind::set_data`.
///
/// Root nodes store this index (`UFNode::Root::set_data_idx`) to locate the
/// data of the set they represent.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "bincode", derive(bincode::Encode, bincode::Decode))]
pub struct SetIndex(pub usize);

impl From<usize> for SetIndex {
    /// Wraps a raw `usize` position as a [`SetIndex`].
    fn from(index: usize) -> Self {
        Self(index)
    }
}
92
/// The entry for a single element in the union-find forest.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "bincode", derive(bincode::Encode, bincode::Decode))]
pub enum UFNode {
    /// Representative of a set. `set_data_idx` locates the set's record in
    /// `UnionFind::set_data`; `rank` is compared by `UnionFind::union` to
    /// decide which root survives a merge.
    Root { set_data_idx: SetIndex, rank: usize },
    /// Non-representative element; holds the position of its parent node.
    Child(ParentPointer),
}
104
105#[derive(Debug, Clone, Error)]
106pub enum UnionFindError {
107 #[error("The set of bitvecs does not partion the elements")]
108 DoesNotPartion,
109 #[error("The set of bitvecs does not have the same length")]
110 LengthMismatch,
111}
112
/// Disjoint-set (union-find) forest over elements addressed by
/// [`ParentPointer`], carrying one payload of type `U` per set.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "bincode", derive(bincode::Encode, bincode::Decode))]
pub struct UnionFind<U> {
    /// One node per element. Nodes are wrapped in `Cell` so that `find` can
    /// rewrite parent pointers (path compression) through `&self`.
    pub nodes: Vec<Cell<UFNode>>,

    /// One record per set, addressed by the `set_data_idx` stored in that
    /// set's root node.
    pub(crate) set_data: Vec<SetData<U>>,
}
132
/// Per-set record stored in `UnionFind::set_data`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "bincode", derive(bincode::Encode, bincode::Decode))]
pub struct SetData<U> {
    /// Intended to be the root node of this set; `add_child` attaches new
    /// children directly to it.
    pub(crate) root_pointer: ParentPointer,
    /// The set's payload. `union` and `replace_set_data_of` take it out
    /// temporarily; `Index<SetIndex>` panics if it is `None`.
    pub(crate) data: Option<U>,
}
140
/// Classification of a set while `UnionFind::extract` partitions elements.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "bincode", derive(bincode::Encode, bincode::Decode))]
enum SetSplit {
    /// Every member seen so far stays in the original structure.
    Left,
    /// Every member seen so far is extracted.
    FullyExtracted,
    /// Members were seen on both sides.
    Split,
    /// No member has been seen yet.
    Unset,
}
150
/// Merge strategy that keeps the first value and discards the second.
///
/// Handy as the `merge` argument of `UnionFind::union`.
pub fn left<E>(l: E, _discarded: E) -> E {
    l
}
154
/// Merge strategy that discards the first value and keeps the second.
///
/// Handy as the `merge` argument of `UnionFind::union`.
pub fn right<E>(_discarded: E, r: E) -> E {
    r
}
158
159impl<U> UnionFind<U> {
160 pub fn n_elements(&self) -> usize {
161 self.nodes.len()
162 }
163
164 pub fn swap_set_data(&mut self, a: SetIndex, b: SetIndex) {
165 self.set_data.swap(a.0, b.0);
166 for n in &mut self.nodes {
167 if let UFNode::Root { set_data_idx, .. } = n.get_mut() {
168 if *set_data_idx == a {
169 *set_data_idx = a;
170 } else if *set_data_idx == b {
171 *set_data_idx = b;
172 }
173 }
174 }
175 }
176
177 #[allow(clippy::manual_map)]
179 pub fn extract<O>(
180 &mut self,
181 mut part: impl FnMut(ParentPointer) -> bool,
182 mut split_set_data: impl FnMut(&U) -> O,
183 mut extract_set_data: impl FnMut(U) -> O,
184 ) -> UnionFind<O> {
185 let mut left = Hedge(0);
186 let mut extracted = Hedge(self.nodes.len());
187 let mut set_split = vec![SetSplit::Unset; self.set_data.len()];
188
189 while left < extracted {
190 if !part(left) {
191 left.0 += 1;
193 let root = self.find_data_index(left);
194 match set_split[root.0] {
195 SetSplit::Left => {}
196 SetSplit::Unset => set_split[root.0] = SetSplit::Left,
197 SetSplit::FullyExtracted => set_split[root.0] = SetSplit::Split,
198 SetSplit::Split => {}
199 }
200 } else {
201 extracted.0 -= 1;
203 let root = self.find_data_index(extracted);
204 if !part(extracted) {
205 match set_split[root.0] {
207 SetSplit::Left => {}
208 SetSplit::Unset => set_split[root.0] = SetSplit::Left,
209 SetSplit::FullyExtracted => set_split[root.0] = SetSplit::Split,
210 SetSplit::Split => {}
211 }
212 self.swap(left, extracted);
213 left.0 += 1;
214 } else {
215 match set_split[root.0] {
216 SetSplit::Left => set_split[root.0] = SetSplit::Split,
217 SetSplit::Unset => set_split[root.0] = SetSplit::FullyExtracted,
218 SetSplit::FullyExtracted => {}
219 SetSplit::Split => {}
220 }
221 }
222 }
223 }
224
225 let mut left_nodes = SetIndex(0);
226 let mut extracted_nodes = SetIndex(self.n_sets());
227 while left_nodes < extracted_nodes {
228 if let SetSplit::Left = set_split[left_nodes.0] {
229 left_nodes.0 += 1;
231 } else {
232 extracted_nodes.0 -= 1;
234 if let SetSplit::Left = set_split[extracted_nodes.0] {
235 self.swap_set_data(left_nodes, extracted_nodes);
237 left_nodes.0 += 1;
238 }
239 }
240 }
241 let mut overlapping_nodes = left_nodes;
242 let mut non_overlapping_extracted = SetIndex(self.n_sets());
243
244 while overlapping_nodes < non_overlapping_extracted {
245 if let SetSplit::Split = set_split[overlapping_nodes.0] {
246 overlapping_nodes.0 += 1;
248 } else {
249 non_overlapping_extracted.0 -= 1;
251 if let SetSplit::Split = set_split[non_overlapping_extracted.0] {
252 self.swap_set_data(overlapping_nodes, non_overlapping_extracted);
254 overlapping_nodes.0 += 1;
255 }
256 }
257 }
258
259 let mut extracted_nodes = self.nodes.split_off(overlapping_nodes.0);
260 let extracted_data: Vec<_> = self
261 .set_data
262 .split_off(overlapping_nodes.0)
263 .into_iter()
264 .map(|a| SetData {
265 data: if let Some(data) = a.data {
266 Some(extract_set_data(data))
267 } else {
268 None
269 },
270 root_pointer: a.root_pointer,
271 })
272 .collect();
273
274 let mut overlapping_data = vec![];
275
276 for i in (left_nodes.0)..(overlapping_nodes.0) {
277 let data = Some(split_set_data(&self[SetIndex(i)]));
278 let root_pointer = self[&SetIndex(i)].root_pointer;
279
280 overlapping_data.push(SetData { root_pointer, data })
281 }
282
283 overlapping_data.extend(extracted_data);
284
285 for n in &mut extracted_nodes {
286 if let UFNode::Root { set_data_idx, .. } = n.get_mut() {
287 set_data_idx.0 -= left_nodes.0
288 }
289 }
290
291 UnionFind {
292 nodes: extracted_nodes,
293 set_data: overlapping_data,
294 }
295 }
296
297 pub fn swap(&mut self, a: ParentPointer, b: ParentPointer) {
298 match (self.is_child(a), self.is_child(b)) {
299 (true, true) => {}
300 (true, false) => {
301 for n in &mut self.nodes {
302 if let UFNode::Child(pp) = n.get_mut() {
303 if *pp == a {
304 *pp = b;
305 }
306 }
307 }
308 }
309 (false, true) => {
310 for n in &mut self.nodes {
311 if let UFNode::Child(pp) = n.get_mut() {
312 if *pp == b {
313 *pp = a;
314 }
315 }
316 }
317 }
318 (false, false) => {
319 for n in &mut self.nodes {
320 if let UFNode::Child(pp) = n.get_mut() {
321 if *pp == a {
322 *pp = b;
323 } else if *pp == b {
324 *pp = a;
325 }
326 }
327 }
328 }
329 }
330 self.nodes.swap(a.0, b.0);
331 }
332
333 pub fn n_sets(&self) -> usize {
334 self.set_data.len()
335 }
336
337 pub fn extend(&mut self, mut other: Self) {
338 let shift_nodes = self.set_data.len();
339 let shift_set = self.nodes.len();
340 other.nodes.iter_mut().for_each(|a| match a.get_mut() {
341 UFNode::Child(a) => a.0 += shift_set,
342 UFNode::Root { set_data_idx, .. } => set_data_idx.0 += shift_nodes,
343 });
344 other
345 .set_data
346 .iter_mut()
347 .for_each(|a| a.root_pointer.0 += shift_set);
348 self.set_data.extend(other.set_data);
349 }
350
351 pub fn iter_set_data(&self) -> impl Iterator<Item = (SetIndex, &U)> {
352 self.set_data
353 .iter()
354 .enumerate()
355 .map(|(i, d)| (SetIndex(i), d.data.as_ref().unwrap()))
356 }
357
358 pub fn iter_set_data_mut(&mut self) -> impl Iterator<Item = (SetIndex, &mut U)> {
359 self.set_data
360 .iter_mut()
361 .enumerate()
362 .map(|(i, d)| (SetIndex(i), d.data.as_mut().unwrap()))
363 }
364
365 pub fn drain_set_data(self) -> impl Iterator<Item = (SetIndex, U)> {
366 self.set_data
367 .into_iter()
368 .enumerate()
369 .map(|(i, d)| (SetIndex(i), d.data.unwrap()))
370 }
371
372 pub fn from_bitvec_partition(bitvec_part: Vec<(U, BitVec)>) -> Result<Self, UnionFindError> {
373 let mut nodes = vec![];
374 let mut set_data = vec![];
375 let mut cover: Option<BitVec> = None;
376
377 for (d, set) in bitvec_part {
378 let len = set.len();
379 if let Some(c) = &mut cover {
380 if c.len() != len {
381 return Err(UnionFindError::LengthMismatch);
382 }
383 if c.intersects(&set) {
384 return Err(UnionFindError::DoesNotPartion);
385 }
386 c.union_with(&set);
387 } else {
388 cover = Some(BitVec::empty(len));
389 nodes = vec![None; len];
390 }
391 let mut first = None;
392 for i in set.included_iter() {
393 if let Some(root) = first {
394 nodes[i.0] = Some(Cell::new(UFNode::Child(root)))
395 } else {
396 first = Some(i);
397 nodes[i.0] = Some(Cell::new(UFNode::Root {
398 set_data_idx: SetIndex(set_data.len()),
399 rank: set.count_ones(),
400 }))
401 }
402 }
403 set_data.push(SetData {
404 root_pointer: first.unwrap(),
405 data: Some(d),
406 });
407 }
408 Ok(UnionFind {
409 nodes: nodes.into_iter().collect::<Option<_>>().unwrap(),
410 set_data,
411 })
412 }
413
414 pub fn new(associated: Vec<U>) -> Self {
417 let nodes = (0..associated.len())
420 .map(|i| {
421 Cell::new(UFNode::Root {
422 set_data_idx: SetIndex(i),
423 rank: 0,
424 })
425 })
426 .collect();
427
428 let set_data = associated
429 .into_iter()
430 .enumerate()
431 .map(|(i, d)| SetData {
432 root_pointer: Hedge(i),
433 data: Some(d),
434 })
435 .collect();
436
437 Self {
438 nodes,
440 set_data,
441 }
442 }
443
444 pub fn find(&self, x: ParentPointer) -> ParentPointer {
446 match self[&x].get() {
447 UFNode::Root { .. } => x,
448 UFNode::Child(parent) => {
449 let root = self.find(parent);
450 self[&x].set(UFNode::Child(root));
452 root
453 }
454 }
455 }
456
457 pub fn is_child(&self, x: ParentPointer) -> bool {
458 matches!(self[&x].get(), UFNode::Child(_))
459 }
465
466 pub fn find_data_index(&self, x: ParentPointer) -> SetIndex {
468 let root = self.find(x);
469 match self[&root].get() {
470 UFNode::Root { set_data_idx, .. } => set_data_idx,
471 UFNode::Child(_) => unreachable!("find always returns a root"),
472 }
473 }
474
475 pub fn find_data(&self, x: ParentPointer) -> &U {
478 &self[self.find_data_index(x)]
479 }
480
481 pub fn union<F>(&mut self, x: ParentPointer, y: ParentPointer, merge: F) -> ParentPointer
488 where
489 F: FnOnce(U, U) -> U,
490 {
491 let rx = self.find(x);
492 let ry = self.find(y);
493 if rx == ry {
494 return rx;
495 }
496
497 let (rank_x, data_x) = match self[&rx].get() {
499 UFNode::Root { rank, set_data_idx } => (rank, set_data_idx),
500 _ => unreachable!(),
501 };
502 let (rank_y, data_y) = match self[&ry].get() {
503 UFNode::Root { rank, set_data_idx } => (rank, set_data_idx),
504 _ => unreachable!(),
505 };
506
507 let (winner, loser, winner_data_idx, loser_data_idx, same_rank) = match rank_x.cmp(&rank_y)
508 {
509 std::cmp::Ordering::Less => (ry, rx, data_y, data_x, false),
510 std::cmp::Ordering::Greater => (rx, ry, data_x, data_y, false),
511 std::cmp::Ordering::Equal => (rx, ry, data_x, data_y, true),
512 };
513
514 if same_rank {
515 if let UFNode::Root { set_data_idx, rank } = self[&winner].get() {
516 self[&winner].set(UFNode::Root {
517 set_data_idx,
518 rank: rank + 1,
519 });
520 }
521 }
522
523 self[&loser].set(UFNode::Child(winner));
525
526 let winner_opt = self[&winner_data_idx].data.take();
529 let loser_opt = self[&loser_data_idx].data.take();
530
531 let merged = merge(
533 winner_opt.expect("winner has no data?"),
534 loser_opt.expect("loser has no data?"),
535 );
536 self[&winner_data_idx].data = Some(merged);
537
538 let last_idx = self.set_data.len() - 1;
540 if loser_data_idx.0 != last_idx {
541 self.set_data.swap(loser_data_idx.0, last_idx);
543
544 let swapped_node = self.set_data[loser_data_idx.0].root_pointer;
546 if let UFNode::Root { set_data_idx, rank } = self[&swapped_node].get() {
547 if set_data_idx.0 == last_idx {
550 self[&swapped_node].set(UFNode::Root {
551 set_data_idx: loser_data_idx,
552 rank,
553 });
554 }
555 }
556
557 if winner_data_idx.0 == last_idx {
559 if let UFNode::Root { set_data_idx, rank } = self[&winner].get() {
560 if set_data_idx.0 == last_idx {
561 self[&winner].set(UFNode::Root {
563 set_data_idx: loser_data_idx,
564 rank,
565 });
566 }
567 }
568 }
569 }
570 self.set_data.pop();
571
572 winner
573 }
574
575 pub fn map_set_data_of<F>(&mut self, x: ParentPointer, f: F)
577 where
578 F: FnOnce(&mut U),
579 {
580 let idx = self.find_data_index(x);
581 let data_ref = &mut self[idx]; f(data_ref);
583 }
584
585 pub fn map_set_data<V, F>(self, mut f: F) -> UnionFind<V>
586 where
587 F: FnMut(SetIndex, U) -> V,
588 {
589 UnionFind {
590 nodes: self.nodes,
591 set_data: self
592 .set_data
593 .into_iter()
594 .enumerate()
595 .map(|(i, a)| SetData {
596 data: Some(f(SetIndex(i), a.data.unwrap())),
597 root_pointer: a.root_pointer,
598 })
599 .collect(),
600 }
601 }
602
603 pub fn map_set_data_ref<'a, F, V>(&'a self, mut f: F) -> UnionFind<V>
604 where
605 F: FnMut(&'a U) -> V,
606 {
607 UnionFind {
608 nodes: self.nodes.clone(),
609 set_data: self
610 .set_data
611 .iter()
612 .map(|d| SetData {
613 root_pointer: d.root_pointer,
614 data: d.data.as_ref().map(&mut f),
615 })
616 .collect(),
617 }
618 }
619
620 pub fn map_set_data_ref_mut<'a, F, V>(&'a mut self, mut f: F) -> UnionFind<V>
621 where
622 F: FnMut(&'a mut U) -> V,
623 {
624 UnionFind {
625 nodes: self.nodes.clone(),
626 set_data: self
627 .set_data
628 .iter_mut()
629 .map(|d| SetData {
630 root_pointer: d.root_pointer,
631 data: d.data.as_mut().map(&mut f),
632 })
633 .collect(),
634 }
635 }
636
637 pub fn map_set_data_ref_result<'a, F, V, Er>(&'a self, mut f: F) -> Result<UnionFind<V>, Er>
638 where
639 F: FnMut(&'a U) -> Result<V, Er>,
640 {
641 let r: Result<Vec<_>, Er> = self
642 .set_data
643 .iter()
644 .map(|d| {
645 let data = d.data.as_ref().map(&mut f).transpose();
646 match data {
647 Ok(data) => Ok(SetData {
648 root_pointer: d.root_pointer,
649 data,
650 }),
651 Err(err) => Err(err),
652 }
653 })
654 .collect();
655 Ok(UnionFind {
656 nodes: self.nodes.clone(),
657 set_data: r?,
658 })
659 }
660
661 pub fn replace_set_data_of<F>(&mut self, x: ParentPointer, f: F)
663 where
664 F: FnOnce(U) -> U,
665 {
666 let idx = self.find_data_index(x);
667 let old_data = self[&idx].data.take().expect("no data to replace");
668 self[&idx].data.replace(f(old_data));
669 }
670
671 pub fn add_child(&mut self, set_id: SetIndex) -> ParentPointer {
672 let root = self[&set_id].root_pointer;
673 let h = Hedge(self.nodes.len());
674 self.nodes.push(Cell::new(UFNode::Child(root)));
675 h
676 }
677}
678
679impl<U> Index<SetIndex> for UnionFind<U> {
685 type Output = U;
686 fn index(&self, idx: SetIndex) -> &Self::Output {
687 self.set_data[idx.0]
688 .data
689 .as_ref()
690 .expect("no data in that slot!")
691 }
692}
693
694impl<U> IndexMut<SetIndex> for UnionFind<U> {
696 fn index_mut(&mut self, idx: SetIndex) -> &mut Self::Output {
697 self.set_data[idx.0]
698 .data
699 .as_mut()
700 .expect("no data in that slot!")
701 }
702}
703
704impl<U> Index<&SetIndex> for UnionFind<U> {
707 type Output = SetData<U>;
708 fn index(&self, idx: &SetIndex) -> &Self::Output {
709 &self.set_data[idx.0]
710 }
711}
712
713impl<U> IndexMut<&SetIndex> for UnionFind<U> {
715 fn index_mut(&mut self, idx: &SetIndex) -> &mut Self::Output {
716 &mut self.set_data[idx.0]
717 }
718}
719
720impl<U> Index<&ParentPointer> for UnionFind<U> {
739 type Output = Cell<UFNode>;
740 fn index(&self, idx: &ParentPointer) -> &Self::Output {
741 &self.nodes[idx.0]
742 }
743}
744
745#[cfg(test)]
747pub mod test;