use std::{
borrow::Borrow,
collections::HashSet,
hash::{BuildHasher, Hash},
};
/// Extension methods for sets that allow lazily constructing an element with a
/// fallible constructor, inserting it only when it is not already present.
pub trait SetInsertExt<T> {
    /// Returns a reference to the element equivalent to `value`, first inserting
    /// the result of `f(value)` if no such element is present.
    ///
    /// `f` is only called when the set does not already contain an element
    /// equivalent to `value`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `f`; in that case nothing is inserted.
    ///
    /// # Panics
    ///
    /// The `HashSet` implementation panics if `f` returns an element that is not
    /// equivalent to `value` (its `Borrow<Q>` form is not equal to `value`).
    /// It also panics if the inserted element cannot be found again via `value`,
    /// which can only happen when `T`'s `Hash`, `Eq` and `Borrow` impls are
    /// inconsistent.
    fn get_or_try_insert_with<Q: ?Sized, F, E>(&mut self, value: &Q, f: F) -> Result<&T, E>
    where
        T: Borrow<Q>,
        Q: Hash + Eq,
        F: FnOnce(&Q) -> Result<T, E>;
}

impl<T, S> SetInsertExt<T> for HashSet<T, S>
where
    T: Eq + Hash,
    S: BuildHasher,
{
    fn get_or_try_insert_with<Q: ?Sized, F, E>(&mut self, value: &Q, f: F) -> Result<&T, E>
    where
        T: Borrow<Q>,
        Q: Hash + Eq,
        F: FnOnce(&Q) -> Result<T, E>,
    {
        if !self.contains(value) {
            let new_value = f(value)?;
            // `f` is arbitrary caller code, so nothing guarantees that the value it
            // builds can be found again through `value`. Reject a mismatch before
            // inserting, so an unrelated element never ends up in the set.
            let borrowed: &Q = <T as Borrow<Q>>::borrow(&new_value);
            assert!(
                borrowed == value,
                "get_or_try_insert_with: `f` returned a value that is not \
                 equivalent to the lookup value"
            );
            self.insert(new_value);
        }
        // An equivalent element is present at this point unless `T`'s `Hash`,
        // `Eq` and `Borrow` impls disagree. That is a logic error in `T`, but it is
        // reachable from safe code, so it must panic. Treating it as unreachable
        // (e.g. via `unreachable_unchecked`) would be undefined behavior.
        Ok(self.get(value).expect(
            "get_or_try_insert_with: no element equivalent to `value` after insertion",
        ))
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::SetInsertExt;

    #[test]
    fn it_works_when_present() {
        let mut set = HashSet::new();
        set.insert(0);
        // When an equivalent element already exists, `f` must not be consulted:
        // neither its different value nor its error may leak through.
        assert_eq!(
            set.get_or_try_insert_with::<_, _, ()>(&0, |_| Ok(1)),
            Ok(&0)
        );
        assert_eq!(
            set.get_or_try_insert_with::<_, _, ()>(&0, |_| Err(())),
            Ok(&0)
        );
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn it_works_when_not_present() {
        let mut set = HashSet::new();
        assert_eq!(
            set.get_or_try_insert_with::<_, _, ()>(&0, |_| Ok(0)),
            Ok(&0),
        );
        // The value produced by `f` must actually be stored.
        assert!(set.contains(&0));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn it_errors() {
        let mut set = HashSet::<i32>::new();
        assert_eq!(set.get_or_try_insert_with(&0, |_| Err(())), Err(()));
        // A failed constructor must leave the set untouched.
        assert!(set.is_empty());
    }

    #[test]
    fn it_supports_borrowed_lookup_keys() {
        let mut set: HashSet<String> = HashSet::new();
        let stored = set.get_or_try_insert_with::<str, _, ()>("key", |s| Ok(s.to_owned()));
        assert_eq!(stored.map(String::as_str), Ok("key"));
        assert!(set.contains("key"));
    }
}