use crate::extra::{ByRank, Extra, GrowableExtra};
use crate::mapping::{
GrowableIdentityMapping, GrowableMapping, Mapping, ParentMapping, RankMapping,
};
use crate::union::Union;
use std::cmp::Ordering;
use std::marker::PhantomData;
use thiserror::Error;
/// A disjoint-set (union-find) structure over elements of type `T`.
///
/// * `parent` maps every element to its parent; an element that is its own
///   parent is the representative (root) of its set.
/// * `extra` holds auxiliary per-element data (`()` by default), e.g. ranks
///   when `E` is `ByRank`.
/// * `V` is the per-element extra value type accepted by `add_with_extra`;
///   `T` and `V` are only tracked at the type level via `phantom`.
#[derive(Clone, Debug)]
pub struct UnionFind<T, V, M, E = ()> {
    // Element -> parent element.
    parent: M,
    // Auxiliary per-element data (e.g. ranks).
    extra: E,
    phantom: PhantomData<(T, V)>,
}
#[derive(Debug, Error, PartialEq)]
pub enum NewUnionFindError<P, E> {
#[error("couldn't construct parent mapping")]
Parent(#[source] P),
#[error("couldn't construct extra mapping")]
Extra(#[source] E),
}
type NewUnionFindErrorSimple<T, V, M, E> =
NewUnionFindError<<M as ParentMapping<T>>::Err, <E as Extra<T, V>>::DefaultMappingErr>;
impl<T, V, M, E> UnionFind<T, V, M, E>
where
M: ParentMapping<T>,
T: Clone,
E: Extra<T, V>,
{
pub fn new(
elems: impl IntoIterator<Item = T> + Clone,
) -> Result<Self, NewUnionFindErrorSimple<T, V, M, E>> {
Ok(Self {
parent: M::identity_map(elems.clone())
.map_err(NewUnionFindErrorSimple::<T, V, M, E>::Parent)?,
extra: E::default_mapping(elems)
.map_err(NewUnionFindErrorSimple::<T, V, M, E>::Extra)?,
phantom: Default::default(),
})
}
}
impl<T: PartialEq, V, M: Mapping<T, T>, E> UnionFind<T, V, M, E> {
    /// Returns the representative (root) of the set containing `elem`
    /// without modifying the structure.
    ///
    /// Follows parent links from `elem` until it reaches an element that is
    /// its own parent. Returns `None` if `elem`, or any element on its parent
    /// chain, has no entry in the parent mapping.
    ///
    /// The walk is iterative, so arbitrarily long parent chains cannot
    /// overflow the stack. Because `find` never compresses paths, such chains
    /// are possible. NOTE: parent links are assumed to be acyclic apart from
    /// root self-links; a longer cycle would make this loop forever.
    pub fn find(&self, elem: &T) -> Option<T>
    where
        T: Clone,
    {
        let mut current = elem.clone();
        loop {
            let parent = self.parent.get(&current)?.clone();
            if parent == current {
                return Some(parent);
            }
            current = parent;
        }
    }

    /// Like [`Self::find`], but also performs path compression. Every element
    /// on the path from `elem` to the root, excluding the root itself, is
    /// re-pointed directly at the root.
    ///
    /// Returns `None`, leaving the mapping untouched, if any link on the path
    /// is missing. Both passes are iterative, so long chains are handled
    /// without recursion.
    pub fn find_shorten(&mut self, elem: &T) -> Option<T>
    where
        T: Clone,
    {
        // Pass 1: locate the root. This fails before any mutation if a link
        // is missing.
        let root = self.find(elem)?;

        // Pass 2: re-walk the same path and point each non-root node at the
        // root. Each node's old parent is read before it is overwritten, and
        // nodes further along the path have not been touched yet.
        let mut current = elem.clone();
        loop {
            let parent = self.parent.get(&current)?.clone();
            if parent == current {
                break;
            }
            self.parent.set(current, root.clone());
            current = parent;
        }
        Some(root)
    }
}
#[derive(Error, Debug)]
pub enum UnionOrAddError<Err, T, V, M: GrowableMapping<T, T>, E: GrowableExtra<T, V>> {
#[error(transparent)]
AddError(AddErrorSimple<T, V, M, E>),
#[error("could not union elements")]
NotUnionable(Err),
}
impl<T: Clone + PartialEq, V, M, E> UnionFind<T, V, M, E>
where
    M: GrowableIdentityMapping<T>,
    E: GrowableExtra<T, V>,
    V: Default,
{
    /// Returns the representative of `elem`'s set. If `elem` cannot be found,
    /// it is first added via [`Self::add`] (with a default extra value) and
    /// then returned as its own representative.
    ///
    /// Unlike [`Self::find_shorten_or_add`], this does not compress paths.
    ///
    /// # Errors
    ///
    /// Propagates the [`AddError`] from [`Self::add`] if inserting fails.
    pub fn find_or_add(&mut self, elem: &T) -> Result<T, AddErrorSimple<T, V, M, E>>
    where
        T: Clone,
    {
        if let Some(root) = self.find(elem) {
            return Ok(root);
        }
        self.add(elem.clone())?;
        Ok(elem.clone())
    }

    /// Path-compressing variant of [`Self::find_or_add`]. It uses
    /// [`Self::find_shorten`] for the lookup and otherwise behaves the same.
    ///
    /// # Errors
    ///
    /// Propagates the [`AddError`] from [`Self::add`] if inserting fails.
    pub fn find_shorten_or_add(&mut self, elem: &T) -> Result<T, AddErrorSimple<T, V, M, E>>
    where
        T: Clone,
    {
        if let Some(root) = self.find_shorten(elem) {
            return Ok(root);
        }
        self.add(elem.clone())?;
        Ok(elem.clone())
    }

    /// Merges the sets containing `elem1` and `elem2`, letting `union` choose
    /// the new representative. Either element is added first if it is missing.
    ///
    /// # Errors
    ///
    /// * [`UnionOrAddError::AddError`] if a missing element could not be
    ///   added. There is no rollback: if `elem1` was added and adding `elem2`
    ///   then fails, `elem1` stays added.
    /// * [`UnionOrAddError::NotUnionable`] if `union` returns an error.
    pub fn union_by_or_add<U: Union<T>>(
        &mut self,
        elem1: &T,
        elem2: &T,
        union: U,
    ) -> Result<UnionStatus, UnionOrAddError<U::Err, T, V, M, E>>
    where
        T: Clone,
    {
        let root1 = match self.find_shorten_or_add(elem1) {
            Ok(root) => root,
            Err(err) => return Err(UnionOrAddError::AddError(err)),
        };
        let root2 = match self.find_shorten_or_add(elem2) {
            Ok(root) => root,
            Err(err) => return Err(UnionOrAddError::AddError(err)),
        };
        self.union_helper(root1, root2, union)
            .map_err(UnionOrAddError::NotUnionable)
    }
}
impl<T: Clone + PartialEq, V, M, R> UnionFind<T, V, M, ByRank<R, T>>
where
    M: GrowableIdentityMapping<T>,
    V: Default,
    R: RankMapping<T>,
    ByRank<R, T>: GrowableExtra<T, V>,
{
    /// Merges the sets containing `elem1` and `elem2` using union by rank.
    /// Either element is first added as a new set if it is missing.
    ///
    /// # Errors
    ///
    /// Returns the [`AddError`] if a missing element could not be added.
    ///
    /// # Panics
    ///
    /// Panics if `union_by_rank_helper` fails, i.e. if one of the two
    /// representatives has no rank recorded in the extra mapping.
    pub fn union_by_rank_or_add(
        &mut self,
        elem1: &T,
        elem2: &T,
    ) -> Result<UnionStatus, AddErrorSimple<T, V, M, ByRank<R, T>>>
    where
        T: Clone,
    {
        let root1 = self.find_shorten_or_add(elem1)?;
        let root2 = self.find_shorten_or_add(elem2)?;
        let status = self
            .union_by_rank_helper(root1, root2)
            .expect("this should never go wrong since we just added these");
        Ok(status)
    }
}
#[derive(Error, Debug)]
pub enum UnionError<Err> {
#[error("the first element given as an argument to union was not found in the union find")]
Elem1NotFound,
#[error("the second element given as an argument to union was not found in the union find")]
Elem2NotFound,
#[error("could not union elements")]
NotUnionable(Err),
}
/// Outcome of a successful union operation.
///
/// The derived ordering follows variant declaration order, so
/// `AlreadyEquivalent < PerformedUnion`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum UnionStatus {
    /// Both elements already had the same representative; nothing changed.
    AlreadyEquivalent,
    /// The two sets were merged.
    PerformedUnion,
}
impl<T: PartialEq, V, M, E> UnionFind<T, V, M, E>
where
    M: Mapping<T, T>,
{
    /// Links two representatives under the element chosen by `union`.
    ///
    /// If `root1 == root2`, nothing changes. Otherwise `union.union` picks the
    /// merged representative and both roots are re-parented to it.
    // NOTE(review): if `union` returns an element other than `root1`/`root2`,
    // that element is not inserted into the parent mapping here. Presumably
    // `Union` implementations return one of their arguments; confirm.
    fn union_helper<U: Union<T>>(
        &mut self,
        root1: T,
        root2: T,
        union: U,
    ) -> Result<UnionStatus, U::Err>
    where
        T: Clone,
    {
        let status = if root1 == root2 {
            UnionStatus::AlreadyEquivalent
        } else {
            let merged = union.union(root1.clone(), root2.clone())?;
            self.parent.set(root1, merged.clone());
            self.parent.set(root2, merged);
            UnionStatus::PerformedUnion
        };
        Ok(status)
    }

    /// Merges the sets containing `elem1` and `elem2`, letting `union` choose
    /// the new representative. Paths from both elements are compressed first.
    ///
    /// # Errors
    ///
    /// * [`UnionError::Elem1NotFound`] / [`UnionError::Elem2NotFound`] if the
    ///   corresponding element, or a link on its parent chain, is missing.
    /// * [`UnionError::NotUnionable`] if `union` returns an error.
    pub fn union_by<U: Union<T>>(
        &mut self,
        elem1: &T,
        elem2: &T,
        union: U,
    ) -> Result<UnionStatus, UnionError<U::Err>>
    where
        T: Clone,
    {
        let root1 = match self.find_shorten(elem1) {
            Some(root) => root,
            None => return Err(UnionError::Elem1NotFound),
        };
        let root2 = match self.find_shorten(elem2) {
            Some(root) => root,
            None => return Err(UnionError::Elem2NotFound),
        };
        self.union_helper(root1, root2, union)
            .map_err(UnionError::NotUnionable)
    }
}
#[derive(Error, Debug)]
pub enum UnionByRankError {
#[error("the first element given as an argument to union was not found in the union find")]
Elem1NotFound,
#[error("the second element given as an argument to union was not found in the union find")]
Elem2NotFound,
}
impl<T, V, M, R> UnionFind<T, V, M, ByRank<R, T>>
where
    M: Mapping<T, T>,
    R: RankMapping<T>,
    T: Clone + PartialEq,
{
    /// Merges the sets containing `elem1` and `elem2` using union by rank.
    ///
    /// The root with the lower rank is attached under the other. On a tie,
    /// `elem2`'s root becomes the representative and its rank is incremented.
    /// Paths from both elements are compressed first.
    ///
    /// # Errors
    ///
    /// Returns [`UnionByRankError::Elem1NotFound`] or
    /// [`UnionByRankError::Elem2NotFound`] if the corresponding element, or
    /// the rank of its representative, cannot be found.
    pub fn union_by_rank(&mut self, elem1: &T, elem2: &T) -> Result<UnionStatus, UnionByRankError> {
        let root1 = match self.find_shorten(elem1) {
            Some(root) => root,
            None => return Err(UnionByRankError::Elem1NotFound),
        };
        let root2 = match self.find_shorten(elem2) {
            Some(root) => root,
            None => return Err(UnionByRankError::Elem2NotFound),
        };
        self.union_by_rank_helper(root1, root2)
    }

    /// Links two representatives according to their ranks.
    ///
    /// Does nothing if they are equal. Fails with `Elem1NotFound` or
    /// `Elem2NotFound` if `root1` or `root2` respectively has no rank.
    fn union_by_rank_helper(
        &mut self,
        root1: T,
        root2: T,
    ) -> Result<UnionStatus, UnionByRankError>
    where
        T: Clone,
    {
        if root1 == root2 {
            return Ok(UnionStatus::AlreadyEquivalent);
        }
        let rank1 = self
            .extra
            .rank(&root1)
            .ok_or(UnionByRankError::Elem1NotFound)?;
        let rank2 = self
            .extra
            .rank(&root2)
            .ok_or(UnionByRankError::Elem2NotFound)?;
        let ordering = rank1.cmp(&rank2);

        // The lower-ranked root becomes the child; on a tie `root2` wins.
        let (child, new_root) = if ordering == Ordering::Greater {
            (root2, root1)
        } else {
            (root1, root2)
        };

        if ordering == Ordering::Equal {
            self.parent.set(child, new_root.clone());
            // Joining two trees of equal rank makes the result one level taller.
            self.extra.set_rank(new_root, rank2 + 1);
        } else {
            self.parent.set(child, new_root);
        }
        Ok(UnionStatus::PerformedUnion)
    }
}
#[derive(Debug, Error, PartialEq)]
pub enum AddError<E, P> {
#[error("couldn't add element to parent mapping")]
Parent(P),
#[error("couldn't add element to extra mapping")]
Extra(#[source] E),
}
type AddErrorSimple<T, V, M, E> =
AddError<<E as GrowableExtra<T, V>>::AddError, <M as GrowableMapping<T, T>>::AddError>;
impl<T: Clone, V, M, E> UnionFind<T, V, M, E>
where
    M: GrowableIdentityMapping<T>,
    E: GrowableExtra<T, V>,
    V: Default,
{
    /// Adds `elem` to the union-find with a default extra value.
    ///
    /// `elem` is registered in the parent mapping via `add_identity`
    /// (presumably as its own parent) and then in the extra mapping.
    ///
    /// # Errors
    ///
    /// Returns [`AddError::Parent`] if the parent mapping rejects `elem`, or
    /// [`AddError::Extra`] if the extra mapping does.
    // NOTE(review): a successful parent insertion is not rolled back when the
    // extra insertion fails, so `elem` can end up present in the parent
    // mapping only.
    pub fn add(&mut self, elem: T) -> Result<(), AddErrorSimple<T, V, M, E>> {
        if let Err(err) = self.parent.add_identity(elem.clone()) {
            return Err(AddError::Parent(err));
        }
        if let Err(err) = self.extra.add(elem, Default::default()) {
            return Err(AddError::Extra(err));
        }
        Ok(())
    }
}
impl<T: Clone, V, M, E> UnionFind<T, V, M, E>
where
    M: GrowableIdentityMapping<T>,
    E: GrowableExtra<T, V>,
{
    /// Adds `elem` to the union-find with `extra` as its auxiliary value.
    ///
    /// `elem` is registered in the parent mapping via `add_identity`
    /// (presumably as its own parent) and then in the extra mapping.
    ///
    /// # Errors
    ///
    /// Returns [`AddError::Parent`] if the parent mapping rejects `elem`, or
    /// [`AddError::Extra`] if the extra mapping does.
    // NOTE(review): a successful parent insertion is not rolled back when the
    // extra insertion fails.
    pub fn add_with_extra(&mut self, elem: T, extra: V) -> Result<(), AddErrorSimple<T, V, M, E>> {
        if let Err(err) = self.parent.add_identity(elem.clone()) {
            return Err(AddError::Parent(err));
        }
        match self.extra.add(elem, extra) {
            Ok(_) => Ok(()),
            Err(err) => Err(AddError::Extra(err)),
        }
    }
}