use crate::Result;
use derivative::Derivative;
use failure::{bail, ensure};
use log::warn;
use std::borrow::Borrow;
use std::cmp::Ordering;
use std::collections::hash_map::Entry::*;
use std::collections::HashMap;
use std::iter;
use std::marker::PhantomData;
use std::ops;
use std::result::Result as StdResult;
use std::slice;
/// A type whose values can be constructed from an identifier alone.
pub trait WithId {
/// Returns a new value of the implementing type built from `id`.
fn with_id(id: &str) -> Self;
}
/// Read/write access to the string identifier of an object.
///
/// The type parameter `T` is not used by the methods declared here.
/// NOTE(review): presumably `T` names the kind of object the identifier refers to, so one
/// type can expose several identifiers — confirm against the implementors.
pub trait Id<T> {
/// Returns the identifier as a string slice.
fn id(&self) -> &str;
/// Replaces the identifier with `id`.
fn set_id(&mut self, id: String);
}
/// Implements `transit_model_collection::Id` for a type by delegating to one of its fields.
///
/// * `impl_id!(Ty, Gen, field)` implements `Id<Gen>` for `Ty`: `id()` returns `&self.field`
///   and `set_id()` assigns the given `String` to `self.field`.
/// * `impl_id!(Ty)` is shorthand for `impl_id!(Ty, Ty, id)`.
#[macro_export]
macro_rules! impl_id {
// Full form: implementing type, type argument of `Id`, name of the field holding the id.
($ty:ty, $gen:ty, $id: ident) => {
impl transit_model_collection::Id<$gen> for $ty {
fn id(&self) -> &str {
&self.$id
}
fn set_id(&mut self, id: std::string::String) {
self.$id = id;
}
}
};
// Short form: the type identifies itself through a field named `id`.
($ty:ty) => {
impl_id!($ty, $ty, id);
};
}
/// Typed position of an object inside a collection.
///
/// Wraps a `u32` together with a `PhantomData<T>` marker, so the index carries the element
/// type without storing a `T`. `Copy`, `Clone`, `PartialEq`, `Eq` and `Hash` are generated by
/// `derivative` with `bound = ""`, i.e. without adding trait bounds on `T`; `Debug` uses the
/// standard derive.
#[derive(Derivative, Debug)]
#[derivative(
Copy(bound = ""),
Clone(bound = ""),
PartialEq(bound = ""),
Eq(bound = ""),
Hash(bound = "")
)]
pub struct Idx<T>(u32, PhantomData<T>);
impl<T> Idx<T> {
    /// Wraps the vector position `idx` into a typed index.
    ///
    /// # Panics
    ///
    /// Panics if `idx` does not fit in the `u32` that backs an `Idx`. A plain `as u32` cast
    /// would silently truncate such a position and yield an index pointing at a different
    /// element, so the overflow is reported instead.
    fn new(idx: usize) -> Self {
        assert!(
            idx <= u32::max_value() as usize,
            "index {} does not fit in the u32 backing an Idx",
            idx
        );
        Idx(idx as u32, PhantomData)
    }

    /// Returns the wrapped position as a `usize`, ready to index a slice or `Vec`.
    fn get(self) -> usize {
        self.0 as usize
    }
}
/// Indexes are totally ordered by the position they wrap; the marker type plays no role.
impl<T> Ord for Idx<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        let (lhs, rhs) = (self.0, other.0);
        lhs.cmp(&rhs)
    }
}
/// Partial ordering simply forwards to the total order defined by `Ord`, so it never
/// returns `None`.
impl<T> PartialOrd for Idx<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
/// A collection of `T` backed by a `Vec<T>`.
///
/// `Default` is generated by `derivative` with `bound = ""`, i.e. without requiring
/// `T: Default`.
#[derive(Debug, Derivative, Clone)]
#[derivative(Default(bound = ""))]
pub struct Collection<T> {
// The stored objects, in insertion order.
objects: Vec<T>,
}
impl<T> From<T> for Collection<T> {
fn from(object: T) -> Self {
Collection::new(vec![object])
}
}
/// Two collections are equal when their object vectors are equal (same length, same
/// objects in the same order).
impl<T: PartialEq> PartialEq for Collection<T> {
    fn eq(&self, other: &Collection<T>) -> bool {
        let (mine, theirs) = (&self.objects, &other.objects);
        mine == theirs
    }
}
impl<T> Collection<T> {
    /// Wraps the vector `v`; the objects keep their order, so an object's index is its
    /// position in `v`.
    pub fn new(v: Vec<T>) -> Self {
        Self { objects: v }
    }

    /// Returns the number of stored objects.
    pub fn len(&self) -> usize {
        self.objects.len()
    }

    /// Iterates over `(index, &object)` pairs in storage order.
    pub fn iter(&self) -> Iter<'_, T> {
        // `Iter` names a function pointer, so the non-capturing closure is coerced to one.
        let tag: fn((usize, &T)) -> (Idx<T>, &T) =
            |(position, object)| (Idx::new(position), object);
        self.objects.iter().enumerate().map(tag)
    }

    /// Iterates over shared references to the objects, without their indexes.
    pub fn values(&self) -> slice::Iter<'_, T> {
        self.objects.as_slice().iter()
    }

    /// Iterates over mutable references to the objects, without their indexes.
    pub fn values_mut(&mut self) -> slice::IterMut<'_, T> {
        self.objects.as_mut_slice().iter_mut()
    }

    /// Iterates over the objects designated by `indexes`, in the order the indexes are given.
    ///
    /// Accepts any iterable whose items borrow as `Idx<T>` (owned indexes or references).
    ///
    /// # Panics
    ///
    /// The returned iterator panics when it reaches an index that is out of bounds.
    pub fn iter_from<I>(&self, indexes: I) -> impl Iterator<Item = &T>
    where
        I: IntoIterator,
        I::Item: Borrow<Idx<T>>,
    {
        let objects = &self.objects;
        indexes.into_iter().map(move |item| {
            let position = item.borrow().get();
            &objects[position]
        })
    }

    /// Appends `item` and returns the index it was stored at.
    pub fn push(&mut self, item: T) -> Idx<T> {
        self.objects.push(item);
        Idx::new(self.objects.len() - 1)
    }

    /// Appends every object of `other`, in order, after the objects already stored.
    pub fn merge(&mut self, other: Self) {
        self.objects.extend(other.objects);
    }

    /// Moves all objects out, leaving the collection empty.
    pub fn take(&mut self) -> Vec<T> {
        let mut drained = Vec::new();
        ::std::mem::swap(&mut self.objects, &mut drained);
        drained
    }

    /// Returns `true` when no object is stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Keeps only the objects for which `f` returns `true`, preserving their relative order.
    /// Indexes obtained before the call may designate a different object afterwards.
    pub fn retain<F: FnMut(&T) -> bool>(&mut self, f: F) {
        let mut kept = self.take();
        kept.retain(f);
        self.objects = kept;
    }
}
/// Iterator over `(Idx<T>, &T)` pairs of a collection.
///
/// It is an enumerated slice iterator mapped through a plain function pointer, which lets
/// the type be named without boxing.
pub type Iter<'a, T> =
iter::Map<iter::Enumerate<slice::Iter<'a, T>>, fn((usize, &T)) -> (Idx<T>, &T)>;
impl<'a, T> IntoIterator for &'a Collection<T> {
type Item = (Idx<T>, &'a T);
type IntoIter = Iter<'a, T>;
fn into_iter(self) -> Iter<'a, T> {
self.iter()
}
}
impl<T> IntoIterator for Collection<T> {
type Item = T;
type IntoIter = ::std::vec::IntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
self.objects.into_iter()
}
}
/// `collection[idx]` gives the object stored at `idx`.
///
/// # Panics
///
/// Panics when `idx` is out of bounds for this collection.
impl<T> ops::Index<Idx<T>> for Collection<T> {
    type Output = T;

    fn index(&self, index: Idx<T>) -> &Self::Output {
        let position = index.get();
        &self.objects[position]
    }
}
/// A collection serializes exactly like its underlying vector of objects.
impl<T> ::serde::Serialize for Collection<T>
where
    T: ::serde::Serialize,
{
    fn serialize<S>(&self, serializer: S) -> StdResult<S::Ok, S::Error>
    where
        S: ::serde::Serializer,
    {
        let objects = &self.objects;
        ::serde::Serialize::serialize(objects, serializer)
    }
}
/// A collection deserializes from whatever a `Vec<T>` deserializes from; the objects are
/// then wrapped with `Collection::new`.
impl<'de, T> ::serde::Deserialize<'de> for Collection<T>
where
    T: ::serde::Deserialize<'de>,
{
    fn deserialize<D>(deserializer: D) -> StdResult<Self, D::Error>
    where
        D: ::serde::Deserializer<'de>,
    {
        let objects: Vec<T> = ::serde::Deserialize::deserialize(deserializer)?;
        Ok(Collection::new(objects))
    }
}
/// A `Collection<T>` paired with a map from identifier to index, so objects can be looked
/// up by their string id.
///
/// `Default` is generated by `derivative` with `bound = ""`, i.e. without requiring
/// `T: Default`.
#[derive(Debug, Derivative, Clone)]
#[derivative(Default(bound = ""))]
pub struct CollectionWithId<T> {
// The stored objects.
collection: Collection<T>,
// Identifier of each stored object mapped to its index in `collection`.
id_to_idx: HashMap<String, Idx<T>>,
}
/// Builds a collection holding exactly one object, indexed by its identifier.
impl<T: Id<T>> From<T> for CollectionWithId<T> {
    fn from(object: T) -> Self {
        // `new` only rejects duplicated identifiers, which a single object cannot produce.
        let singleton = vec![object];
        Self::new(singleton).unwrap()
    }
}
impl<T: Id<T>> CollectionWithId<T> {
    /// Builds the collection from `v`, indexing every object by its identifier.
    ///
    /// # Errors
    ///
    /// Fails with "`<id>` already found" as soon as two objects of `v` share an identifier.
    pub fn new(v: Vec<T>) -> Result<Self> {
        let mut id_to_idx = HashMap::default();
        for (position, object) in v.iter().enumerate() {
            let previous = id_to_idx.insert(object.id().to_string(), Idx::new(position));
            ensure!(previous.is_none(), "{} already found", object.id());
        }
        let collection = Collection::new(v);
        Ok(CollectionWithId {
            collection,
            id_to_idx,
        })
    }

    /// Gives read access to the identifier-to-index map.
    pub fn get_id_to_idx(&self) -> &HashMap<String, Idx<T>> {
        &self.id_to_idx
    }

    /// Returns a mutable handle on the object stored at `idx`.
    ///
    /// The handle records the identifier the object has right now, so that a change of
    /// identifier can be detected when the handle is dropped.
    ///
    /// # Panics
    ///
    /// Panics when `idx` is out of bounds for this collection.
    pub fn index_mut(&mut self, idx: Idx<T>) -> RefMut<'_, T> {
        let old_id = self.objects[idx.get()].id().to_string();
        RefMut {
            idx,
            old_id,
            collection: self,
        }
    }

    /// Returns a mutable handle on the object whose identifier is `id`, or `None` when no
    /// such object is stored.
    pub fn get_mut(&mut self, id: &str) -> Option<RefMut<'_, T>> {
        let idx = self.get_idx(id)?;
        Some(self.index_mut(idx))
    }

    /// Appends `item` and returns the index it was stored at.
    ///
    /// # Errors
    ///
    /// Fails with "`<id>` already found" when an object with the same identifier is already
    /// stored; the collection is left untouched in that case.
    pub fn push(&mut self, item: T) -> Result<Idx<T>> {
        let idx = Idx::new(self.collection.objects.len());
        match self.id_to_idx.entry(item.id().to_string()) {
            Occupied(_) => bail!("{} already found", item.id()),
            Vacant(slot) => {
                slot.insert(idx);
                self.collection.objects.push(item);
                Ok(idx)
            }
        }
    }

    /// Keeps only the objects for which `f` returns `true`, then rebuilds the identifier
    /// index. Indexes obtained before the call may designate a different object afterwards.
    pub fn retain<F: FnMut(&T) -> bool>(&mut self, f: F) {
        let mut kept = self.take();
        kept.retain(f);
        // Removing objects cannot introduce a duplicated identifier.
        let rebuilt = Self::new(kept).unwrap();
        *self = rebuilt;
    }

    /// Pushes every object of `other` into this collection.
    ///
    /// # Errors
    ///
    /// Stops at the first object whose identifier is already stored and returns that error;
    /// objects pushed before the failure stay in the collection.
    pub fn try_merge(&mut self, other: Self) -> Result<()> {
        other
            .into_iter()
            .try_for_each(|item| self.push(item).map(|_| ()))
    }

    /// Pushes every object of `other` into this collection, silently skipping the objects
    /// whose identifier is already stored.
    pub fn merge(&mut self, other: Self) {
        for item in other {
            // Best effort: a rejected push is deliberately ignored.
            let _ = self.push(item);
        }
    }

    /// Merges the objects produced by `iterator` into this collection.
    ///
    /// When an object with the same identifier is already stored, `f` is called with the
    /// stored object (mutably) and the incoming one; otherwise the incoming object is pushed.
    pub fn merge_with<I, F>(&mut self, iterator: I, mut f: F)
    where
        F: FnMut(&mut T, &T),
        I: IntoIterator<Item = T>,
    {
        for incoming in iterator {
            // The handle must be gone before `push` can borrow the collection again, hence
            // the `continue` instead of an `else` branch.
            if let Some(mut existing) = self.get_mut(incoming.id()) {
                f(&mut *existing, &incoming);
                continue;
            }
            self.push(incoming).unwrap();
        }
    }

    /// Returns `true` when no object is stored.
    pub fn is_empty(&self) -> bool {
        self.collection.objects.is_empty()
    }
}
impl<T: Id<T> + WithId> CollectionWithId<T> {
    /// Returns a mutable handle on the object identified by `id`, first creating it with
    /// `WithId::with_id` when no such object is stored.
    pub fn get_or_create<'a>(&'a mut self, id: &str) -> RefMut<'a, T> {
        let from_id = || <T as WithId>::with_id(id);
        self.get_or_create_with(id, from_id)
    }
}
impl<T: Id<T>> CollectionWithId<T> {
    /// Returns a mutable handle on the object identified by `id`.
    ///
    /// When no such object is stored, one is built by calling `f`, its identifier is
    /// overwritten with `id`, and it is pushed into the collection first. `f` is not called
    /// when the object already exists.
    pub fn get_or_create_with<'a, F>(&'a mut self, id: &str, mut f: F) -> RefMut<'a, T>
    where
        F: FnMut() -> T,
    {
        let idx = match self.get_idx(id) {
            Some(existing) => existing,
            None => {
                let mut created = f();
                created.set_id(id.to_string());
                self.push(created).unwrap()
            }
        };
        self.index_mut(idx)
    }
}
/// Extending pushes each item in turn; an item whose push fails is dropped and the error is
/// logged as a warning, without interrupting the loop.
impl<T: Id<T>> iter::Extend<T> for CollectionWithId<T> {
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        for item in iter {
            if let Err(e) = self.push(item) {
                warn!("{}", e);
            }
        }
    }
}
impl<T> CollectionWithId<T> {
pub fn contains_id(&self, id: &str) -> bool {
self.id_to_idx.contains_key(id)
}
pub fn get_idx(&self, id: &str) -> Option<Idx<T>> {
self.id_to_idx.get(id).cloned()
}
pub fn get(&self, id: &str) -> Option<&T> {
self.get_idx(id).map(|idx| &self[idx])
}
pub fn into_vec(self) -> Vec<T> {
self.collection.objects
}
pub fn take(&mut self) -> Vec<T> {
self.id_to_idx.clear();
::std::mem::replace(&mut self.collection.objects, Vec::new())
}
}
/// Mutable handle on one object of a `CollectionWithId`.
///
/// It dereferences to the object and, when dropped, updates the collection's identifier
/// index if the object's identifier differs from the one recorded in `old_id`.
pub struct RefMut<'a, T: Id<T>> {
// Index of the object this handle points at.
idx: Idx<T>,
// The collection owning the object, borrowed mutably for the lifetime of the handle.
collection: &'a mut CollectionWithId<T>,
// Identifier of the object at the time the handle was created.
old_id: String,
}
/// Mutable access to the object the handle points at.
impl<'a, T: Id<T>> ops::DerefMut for RefMut<'a, T> {
    fn deref_mut(&mut self) -> &mut T {
        let position = self.idx.get();
        let objects = &mut self.collection.collection.objects;
        &mut objects[position]
    }
}
/// Shared access to the object the handle points at.
impl<'a, T: Id<T>> ops::Deref for RefMut<'a, T> {
    type Target = T;

    fn deref(&self) -> &T {
        let position = self.idx.get();
        &self.collection.collection.objects[position]
    }
}
/// Keeps the identifier index in sync when the handle goes away.
///
/// If the object's identifier still equals the one recorded at creation, nothing happens.
/// Otherwise the old identifier is removed from the index and the new one is registered for
/// the same position.
///
/// # Panics
///
/// Panics when the new identifier was already present in the index.
impl<'a, T: Id<T>> Drop for RefMut<'a, T> {
    fn drop(&mut self) {
        if self.id() == self.old_id {
            return;
        }
        self.collection.id_to_idx.remove(&self.old_id);
        let renamed = self.id().to_string();
        let displaced = self.collection.id_to_idx.insert(renamed, self.idx);
        assert!(
            displaced.is_none(),
            "changing id {} to {} already used",
            self.old_id,
            self.id()
        );
    }
}
/// Equality only compares the stored objects; the identifier index is not looked at.
impl<T: PartialEq> PartialEq for CollectionWithId<T> {
    fn eq(&self, other: &CollectionWithId<T>) -> bool {
        PartialEq::eq(&self.collection, &other.collection)
    }
}
impl<T> ops::Deref for CollectionWithId<T> {
type Target = Collection<T>;
fn deref(&self) -> &Collection<T> {
&self.collection
}
}
/// Iterating over `&CollectionWithId<T>` yields the `(index, &object)` pairs of the inner
/// collection.
impl<'a, T> IntoIterator for &'a CollectionWithId<T> {
    type Item = (Idx<T>, &'a T);
    type IntoIter = Iter<'a, T>;

    fn into_iter(self) -> Self::IntoIter {
        self.collection.iter()
    }
}
impl<T> IntoIterator for CollectionWithId<T> {
type Item = T;
type IntoIter = ::std::vec::IntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
self.collection.into_iter()
}
}
/// Only the objects are serialized, as a plain sequence; the identifier index is not
/// written out.
impl<T> ::serde::Serialize for CollectionWithId<T>
where
    T: ::serde::Serialize + Id<T>,
{
    fn serialize<S>(&self, serializer: S) -> StdResult<S::Ok, S::Error>
    where
        S: ::serde::Serializer,
    {
        let objects = &self.collection.objects;
        ::serde::Serialize::serialize(objects, serializer)
    }
}
impl<'de, T> ::serde::Deserialize<'de> for CollectionWithId<T>
where
T: ::serde::Deserialize<'de> + Id<T>,
{
fn deserialize<D>(deserializer: D) -> StdResult<Self, D::Error>
where
D: ::serde::Deserializer<'de>,
{
use serde::de::Error;
::serde::Deserialize::deserialize(deserializer)
.and_then(|v| CollectionWithId::new(v).map_err(D::Error::custom))
}
}