#![warn(missing_docs)]
use std::collections::{HashMap, BTreeMap, HashSet, VecDeque};
use std::marker::PhantomData;
use std::{hash, borrow};
/// Error returned when a mutable reference into a collection cannot be handed out.
#[derive(Eq, PartialEq, Ord, PartialOrd, Clone, Copy, Debug)]
pub enum SplitMutError {
    /// The lookup for the requested key returned no value
    /// (e.g. the index is out of bounds or the key is absent).
    NoValue,
    /// The requested key resolves to a value that was already handed out
    /// mutably by the same call or handle (detected by address).
    SameValue,
}

impl SplitMutError {
    /// Human-readable message for this error, shared by `Display` and the
    /// legacy `Error::description` so the two cannot drift apart.
    fn message(self) -> &'static str {
        match self {
            SplitMutError::NoValue => "No value",
            SplitMutError::SameValue => "Duplicate values",
        }
    }
}

impl std::error::Error for SplitMutError {
    // Still overridden so callers of the deprecated `description()` keep
    // receiving the same text as before.
    fn description(&self) -> &'static str {
        self.message()
    }
}

impl std::fmt::Display for SplitMutError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Use the shared message directly instead of the deprecated
        // `Error::description`, which emits a deprecation warning.
        f.write_str(self.message())
    }
}
/// Intermediate lookup result: a raw pointer to the value (lifetime erased)
/// or the reason there is none.
type R<V> = Result<*mut V, SplitMutError>;
/// Converts an `Option<&mut V>` lookup into an `R<V>`, mapping a missing
/// value to `SplitMutError::NoValue`.
#[inline]
fn to_r<V>(s: Option<&mut V>) -> R<V> {
    match s {
        Some(slot) => Ok(slot as *mut V),
        None => Err(SplitMutError::NoValue),
    }
}
/// Returns `Err(SameValue)` when both `a` and `b` are successful lookups
/// pointing at the same address; otherwise returns `b` unchanged.
#[inline]
fn check_r<V>(a: &R<V>, b: R<V>) -> R<V> {
    if let (&Ok(first), &Ok(second)) = (a, &b) {
        if first == second {
            return Err(SplitMutError::SameValue);
        }
    }
    b
}
/// Turns a raw-pointer lookup result back into a reference with a
/// caller-chosen lifetime `'a`.
///
/// # Safety
///
/// Nothing ties `'a` to the collection. On success the caller must ensure
/// the pointee stays valid for `'a` and is not reachable through any other
/// live reference.
#[inline]
unsafe fn from_r<'a, V>(a: R<V>) -> Result<&'a mut V, SplitMutError> {
    match a {
        Ok(raw) => Ok(&mut *raw),
        Err(e) => Err(e),
    }
}
/// Mutable access to several values of one collection at the same time.
///
/// The checked `get2_mut` / `get3_mut` / `get4_mut` methods look each key up
/// with `get1_mut`. They return `Err(SplitMutError::NoValue)` for a key without
/// a value, and `Err(SplitMutError::SameValue)` for a key whose value has the
/// same address as an earlier successful lookup, so the returned references
/// never alias. Because values are compared by address, zero-sized values may
/// compare equal even under different keys (NOTE(review): confirm this is
/// intended for ZST element types).
///
/// # Safety
///
/// The default methods keep raw pointers from earlier `get1_mut` /
/// `get1_unchecked_mut` calls while making further calls on the same
/// collection, then turn all of them into live `&mut V` at once. Implementors
/// must therefore guarantee that looking up one key never moves, drops or
/// otherwise invalidates a value previously returned for another key.
pub unsafe trait SplitMut<K, V> {
/// Returns a mutable reference to the value for `k1`, or `None` if there is none.
fn get1_mut(&mut self, k1: K) -> Option<&mut V>;
/// Returns a mutable reference to the value for `k1` without checking that it exists.
///
/// # Safety
///
/// `k1` must have a value in the collection; otherwise the behaviour is undefined.
unsafe fn get1_unchecked_mut(&mut self, k1: K) -> &mut V;
/// Returns mutable references to the values for `k1` and `k2`.
///
/// Each result is `Err(NoValue)` when its key has no value. The second is
/// `Err(SameValue)` when both keys resolve to the same value.
fn get2_mut(&mut self, k1: K, k2: K) -> (Result<&mut V, SplitMutError>, Result<&mut V, SplitMutError>) {
let p1 = to_r(self.get1_mut(k1));
let p2 = to_r(self.get1_mut(k2));
// Compare addresses while still holding raw pointers, before any reference exists.
let p2 = check_r(&p1, p2);
// SAFETY: any pointer equal to p1 was replaced by Err above, so the references are disjoint.
unsafe { (from_r(p1), from_r(p2)) }
}
/// Returns mutable references to the values for `k1`, `k2` and `k3`.
///
/// Each result is `Err(NoValue)` when its key has no value, or `Err(SameValue)`
/// when its value was already returned at an earlier position.
fn get3_mut(&mut self, k1: K, k2: K, k3: K) -> (Result<&mut V, SplitMutError>,
Result<&mut V, SplitMutError>, Result<&mut V, SplitMutError>) {
let p1 = to_r(self.get1_mut(k1));
let p2 = to_r(self.get1_mut(k2));
let p3 = to_r(self.get1_mut(k3));
let p2 = check_r(&p1, p2);
let p3 = check_r(&p1, p3);
// If p2 was rejected it equals p1, which p3 was already compared against.
let p3 = check_r(&p2, p3);
// SAFETY: every pair of remaining Ok pointers was checked to be distinct.
unsafe { (from_r(p1), from_r(p2), from_r(p3)) }
}
/// Returns mutable references to the values for `k1`..`k4`.
///
/// Each result is `Err(NoValue)` when its key has no value, or `Err(SameValue)`
/// when its value was already returned at an earlier position.
fn get4_mut(&mut self, k1: K, k2: K, k3: K, k4: K) -> (Result<&mut V, SplitMutError>,
Result<&mut V, SplitMutError>, Result<&mut V, SplitMutError>, Result<&mut V, SplitMutError>) {
let p1 = to_r(self.get1_mut(k1));
let p2 = to_r(self.get1_mut(k2));
let p3 = to_r(self.get1_mut(k3));
let p4 = to_r(self.get1_mut(k4));
// Pairwise checks: each later pointer against every earlier one.
let p2 = check_r(&p1, p2);
let p3 = check_r(&p1, p3);
let p3 = check_r(&p2, p3);
let p4 = check_r(&p1, p4);
let p4 = check_r(&p2, p4);
let p4 = check_r(&p3, p4);
// SAFETY: every pair of remaining Ok pointers was checked to be distinct.
unsafe { (from_r(p1), from_r(p2), from_r(p3), from_r(p4)) }
}
/// Returns a `GetMuts` handle that hands out one reference per `at` call and
/// rejects values it has already handed out.
fn get_muts(&mut self) -> GetMuts<K, V, Self> { GetMuts(self, HashSet::new(), PhantomData) }
/// Returns an iterator that yields `GetMuts::at(k)` for every key `k` from `i`,
/// so a value reached a second time yields `Err(SameValue)`.
fn get_mut_iter<I: Iterator<Item=K>>(&mut self, i: I) -> GetMutIter<K, V, Self, I> { GetMutIter(self.get_muts(), i) }
/// Returns mutable references to the values for `k1` and `k2` without any checks.
///
/// # Safety
///
/// Both keys must have values, and the values must be distinct; otherwise the
/// returned references are invalid or alias each other.
unsafe fn get2_unchecked_mut(&mut self, k1: K, k2: K) -> (&mut V, &mut V) {
// Later keys are fetched first and parked as raw pointers, so the final
// call can return a reference that borrows `self` directly.
let p2 = self.get1_unchecked_mut(k2) as *mut V;
(self.get1_unchecked_mut(k1), &mut *p2)
}
/// Returns mutable references to the values for `k1`, `k2` and `k3` without any checks.
///
/// # Safety
///
/// All keys must have values, and the values must be pairwise distinct.
unsafe fn get3_unchecked_mut(&mut self, k1: K, k2: K, k3: K) -> (&mut V, &mut V, &mut V) {
let p2 = self.get1_unchecked_mut(k2) as *mut V;
let p3 = self.get1_unchecked_mut(k3) as *mut V;
(self.get1_unchecked_mut(k1), &mut *p2, &mut *p3)
}
/// Returns mutable references to the values for `k1`..`k4` without any checks.
///
/// # Safety
///
/// All keys must have values, and the values must be pairwise distinct.
unsafe fn get4_unchecked_mut(&mut self, k1: K, k2: K, k3: K, k4: K) -> (&mut V, &mut V, &mut V, &mut V) {
let p2 = self.get1_unchecked_mut(k2) as *mut V;
let p3 = self.get1_unchecked_mut(k3) as *mut V;
let p4 = self.get1_unchecked_mut(k4) as *mut V;
(self.get1_unchecked_mut(k1), &mut *p2, &mut *p3, &mut *p4)
}
}
pub struct GetMuts<'a, K, V, A: 'a + SplitMut<K, V> + ?Sized>(&'a mut A, HashSet<*mut V>, PhantomData<*const K>);
impl<'a, K, V, A: 'a + SplitMut<K, V> + ?Sized> GetMuts<'a, K, V, A> {
pub fn at(&mut self, k: K) -> Result<&'a mut V, SplitMutError> {
let p = try!(to_r(self.0.get1_mut(k)));
if !self.1.insert(p) { return Err(SplitMutError::SameValue) };
Ok(unsafe { &mut *p })
}
}
/// Iterator returned by `SplitMut::get_mut_iter`.
///
/// Yields one `Result` per key produced by the wrapped key iterator. Lookups
/// go through a `GetMuts` handle, so a value reached a second time yields
/// `Err(SplitMutError::SameValue)`.
pub struct GetMutIter<'a, K, V, A: 'a + SplitMut<K, V> + ?Sized, I>(GetMuts<'a, K, V, A>, I);

impl<'a, K, V: 'a, A: 'a + SplitMut<K, V> + ?Sized, I: Iterator<Item = K>> Iterator
    for GetMutIter<'a, K, V, A, I>
{
    type Item = Result<&'a mut V, SplitMutError>;

    fn next(&mut self) -> Option<Self::Item> {
        self.1.next().map(|k| self.0.at(k))
    }

    // Exactly one item is produced per key, so the key iterator's bounds
    // carry over unchanged; this lets `collect` pre-size its buffer.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.1.size_hint()
    }
}
// Slice lookups never move elements, so pointers from earlier lookups stay
// valid across later ones, as `SplitMut` requires.
unsafe impl<'a, V> SplitMut<usize, V> for &'a mut [V] {
    #[inline]
    fn get1_mut(&mut self, k: usize) -> Option<&mut V> {
        let slice: &mut [V] = &mut **self;
        slice.get_mut(k)
    }

    #[inline]
    unsafe fn get1_unchecked_mut(&mut self, k: usize) -> &mut V {
        // Caller guarantees `k < len`.
        let slice: &mut [V] = &mut **self;
        slice.get_unchecked_mut(k)
    }
}
// `Vec` lookups go through its slice and never reallocate, so pointers from
// earlier lookups stay valid across later ones, as `SplitMut` requires.
// (The previous unused `'a` impl lifetime has been dropped.)
unsafe impl<V> SplitMut<usize, V> for Vec<V> {
    #[inline]
    fn get1_mut(&mut self, k: usize) -> Option<&mut V> {
        self.get_mut(k)
    }

    #[inline]
    unsafe fn get1_unchecked_mut(&mut self, k: usize) -> &mut V {
        // Caller guarantees `k < self.len()`.
        self.get_unchecked_mut(k)
    }
}
// `VecDeque::get_mut` does not move elements, so pointers from earlier
// lookups stay valid across later ones, as `SplitMut` requires.
unsafe impl<V> SplitMut<usize, V> for VecDeque<V> {
    #[inline]
    fn get1_mut(&mut self, k: usize) -> Option<&mut V> {
        self.get_mut(k)
    }

    #[inline]
    unsafe fn get1_unchecked_mut(&mut self, k: usize) -> &mut V {
        // `VecDeque` has no unchecked accessor. The caller guarantees `k` is in
        // bounds, so `None` cannot occur. State that explicitly rather than
        // transmuting `Option<&mut V>` into `&mut V`, which bypasses type checking.
        match self.get_mut(k) {
            Some(v) => v,
            None => std::hint::unreachable_unchecked(),
        }
    }
}
// `HashMap::get_mut` neither inserts nor removes entries, so pointers from
// earlier lookups stay valid across later ones, as `SplitMut` requires.
unsafe impl<'a, K, Q, V, S> SplitMut<&'a Q, V> for HashMap<K, V, S>
where
    K: hash::Hash + Eq + borrow::Borrow<Q>,
    Q: hash::Hash + Eq + ?Sized,
    S: hash::BuildHasher,
{
    #[inline]
    fn get1_mut(&mut self, k: &'a Q) -> Option<&mut V> {
        self.get_mut(k)
    }

    #[inline]
    unsafe fn get1_unchecked_mut(&mut self, k: &'a Q) -> &mut V {
        // The caller guarantees `k` is present, so `None` cannot occur. State
        // that explicitly instead of transmuting `Option<&mut V>` to `&mut V`.
        match self.get_mut(k) {
            Some(v) => v,
            None => std::hint::unreachable_unchecked(),
        }
    }
}
// `BTreeMap::get_mut` neither inserts nor removes entries, so pointers from
// earlier lookups stay valid across later ones, as `SplitMut` requires.
unsafe impl<'a, K, Q, V> SplitMut<&'a Q, V> for BTreeMap<K, V>
where
    K: Ord + borrow::Borrow<Q>,
    Q: Ord + ?Sized,
{
    #[inline]
    fn get1_mut(&mut self, k: &'a Q) -> Option<&mut V> {
        self.get_mut(k)
    }

    #[inline]
    unsafe fn get1_unchecked_mut(&mut self, k: &'a Q) -> &mut V {
        // The caller guarantees `k` is present, so `None` cannot occur. State
        // that explicitly instead of transmuting `Option<&mut V>` to `&mut V`.
        match self.get_mut(k) {
            Some(v) => v,
            None => std::hint::unreachable_unchecked(),
        }
    }
}
// Requesting the same key twice yields the value once, then `SameValue`.
#[test]
fn hash_same() {
    let mut map: HashMap<u8, u16> = HashMap::new();
    map.insert(3, 5);
    let pair = map.get2_mut(&3, &3);
    assert_eq!(pair, (Ok(&mut 5u16), Err(SplitMutError::SameValue)));
}
// Swap two map values through simultaneous borrows, then check the missing-key
// paths and the unchecked accessor.
#[test]
fn hash_reg() {
    let mut map: HashMap<u8, u16> = HashMap::new();
    map.insert(3, 5);
    map.insert(4, 9);
    {
        let (first, second) = map.get2_mut(&3, &4);
        let (first, second) = (first.unwrap(), second.unwrap());
        std::mem::swap(first, second);
    }
    assert_eq!(map.get2_mut(&2, &2), (Err(SplitMutError::NoValue), Err(SplitMutError::NoValue)));
    let swapped = unsafe { map.get2_unchecked_mut(&3, &4) };
    assert_eq!(swapped, (&mut 9u16, &mut 5u16));
    assert_eq!(map.get2_mut(&2, &3), (Err(SplitMutError::NoValue), Ok(&mut 9u16)));
}
// Borrowed `&str` keys look up `String` keys; the repeated "me" is rejected.
#[test]
fn tree_borrow() {
    let mut h = BTreeMap::new();
    h.insert(String::from("borrow"), 1);
    h.insert(String::from("me"), 2);
    let slice = ["me", "borrow", "me"];
    // `iter().cloned()` yields `&str` in every edition. `into_iter()` on an
    // array yields `&&str` before edition 2021 but `&str` from 2021 on, which
    // would break the old `|&k| k` pattern.
    let z: Vec<_> = h.get_mut_iter(slice.iter().cloned()).collect();
    assert_eq!(&*z, [Ok(&mut 2), Ok(&mut 1), Err(SplitMutError::SameValue)]);
}
// Requesting the same index twice yields the value once, then `SameValue`.
#[test]
fn deque_same() {
    let mut deque: VecDeque<u16> = VecDeque::new();
    deque.push_front(5);
    let pair = deque.get2_mut(0, 0);
    assert_eq!(pair, (Ok(&mut 5u16), Err(SplitMutError::SameValue)));
}
// Swap two deque elements through simultaneous borrows, then check the
// out-of-range paths and the unchecked accessor.
#[test]
fn deque_reg() {
    let mut deque: VecDeque<u16> = VecDeque::new();
    deque.push_back(5);
    deque.push_back(9);
    {
        let (front, back) = deque.get2_mut(0, 1);
        let (front, back) = (front.unwrap(), back.unwrap());
        std::mem::swap(front, back);
    }
    assert_eq!(deque.get2_mut(2, 2), (Err(SplitMutError::NoValue), Err(SplitMutError::NoValue)));
    let swapped = unsafe { deque.get2_unchecked_mut(0, 1) };
    assert_eq!(swapped, (&mut 9u16, &mut 5u16));
    assert_eq!(deque.get2_mut(2, 0), (Err(SplitMutError::NoValue), Ok(&mut 9u16)));
}
// Exercises `get3_mut` and the incremental `GetMuts` handle on a `Vec`.
#[test]
fn vec() {
    let mut words = vec!["Hello", "world", "!"];
    {
        let (first, second, third) = words.get3_mut(0, 1, 2);
        *third.unwrap() = "universe";
        let (first, second) = (first.unwrap(), second.unwrap());
        std::mem::swap(first, second);
    }
    assert_eq!(&*words, &["world", "Hello", "universe"]);
    {
        let mut handle = words.get_muts();
        let first = handle.at(0);
        let second = handle.at(1);
        assert_eq!(first, Ok(&mut "world"));
        assert_eq!(second, Ok(&mut "Hello"));
        std::mem::swap(first.unwrap(), second.unwrap());
        // Index 0 was already handed out; index 3 is out of range.
        assert_eq!(handle.at(0), Err(SplitMutError::SameValue));
        assert_eq!(handle.at(3), Err(SplitMutError::NoValue));
    }
    assert_eq!(&*words, &["Hello", "world", "universe"]);
}