use crate::cacheable::Cacheable;
use crate::predicate::BasicPredicate;
use std::any::{Any, TypeId};
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet, hash_map::DefaultHasher};
use std::hash::{Hash, Hasher};
use std::panic::{AssertUnwindSafe, catch_unwind};
use std::sync::Arc;
/// Shared handle to a value; a plain alias for `Arc<T>`.
type ArcValue<T> = Arc<T>;
/// Type-erased predicate over a borrowed value (`Send + Sync`).
type PredicateFn<T> = dyn Fn(&T) -> bool + Send + Sync;
/// Type-erased one-to-one transformation from one shared value to another.
type MapFn<T> = dyn Fn(ArcValue<T>) -> ArcValue<T> + Send + Sync;
/// Type-erased one-to-many transformation from one shared value to a list of shared values.
type FlatMapFn<T> = dyn Fn(ArcValue<T>) -> Vec<ArcValue<T>> + Send + Sync;
/// Type-erased comparator over two shared values.
type CompareFn<T> = dyn Fn(&ArcValue<T>, &ArcValue<T>) -> Ordering + Send + Sync;
/// Type-erased key extractor; the concrete key type is hidden behind `ErasedKey`.
type KeyFn<T> = dyn Fn(&T) -> ErasedKey + Send + Sync;
/// Type-erased reducer that consumes the whole list and yields at most one shared value.
type FoldFn<T> = dyn Fn(Vec<ArcValue<T>>) -> Option<ArcValue<T>> + Send + Sync;
/// A single in-memory query step over a list of `Arc<T>` values.
///
/// Each variant carries the configuration of one operation. [`MemQ::apply`]
/// runs one step; [`MemQ::apply_all`] runs a slice of steps in order.
// NOTE(review): `#[derive(Clone)]` only generates `Clone` when `T: Clone`,
// even though most variants just hold `Arc`s — confirm that bound is intended.
#[non_exhaustive]
#[derive(Clone)]
pub enum MemQ<T: Cacheable> {
/// Keeps only the values the filter accepts.
Filter(Filter<T>),
/// Replaces every value with the result of the mapping closure.
Map(Map<T>),
/// Replaces every value with the (possibly empty) list returned by the closure.
FlatMap(FlatMap<T>),
/// Keeps at most the first `count` values.
Take(Take),
/// Drops the first `count` values.
Skip(Skip),
/// Appends the stored values after the input values.
Chain(Chain<T>),
/// Sorts the values with the stored comparator.
Sort(Sort<T>),
/// Keeps the first value seen for each distinct key.
Unique(Unique<T>),
/// Reorders values so that values sharing a key are adjacent.
GroupBy(GroupBy<T>),
/// Moves values matching the predicate ahead of the non-matching ones.
Partition(Partition<T>),
/// Collapses the whole list into zero or one value.
Fold(Fold<T>),
}
/// Configuration for [`MemQ::Filter`]; wraps one of the supported predicate kinds.
#[derive(Clone)]
pub struct Filter<T: Cacheable> {
/// Which kind of predicate decides whether a value is kept.
kind: FilterKind<T>,
}
/// The two predicate flavours a [`Filter`] can hold.
#[derive(Clone)]
enum FilterKind<T: Cacheable> {
/// A `BasicPredicate`, evaluated through its own `evaluate` method.
Basic(BasicPredicate<T>),
/// An arbitrary shared closure; a panic inside it is caught and counts as `false`.
Closure(Arc<PredicateFn<T>>),
}
impl<T: Cacheable> Filter<T> {
    /// Decides whether `value` passes this filter.
    ///
    /// A basic predicate is evaluated directly. A closure predicate runs
    /// inside `catch_unwind`, and a panicking closure is treated as a
    /// rejection (`false`) instead of unwinding into the caller.
    fn evaluate(&self, value: &T) -> bool {
        let closure = match &self.kind {
            FilterKind::Basic(basic) => return basic.evaluate(value),
            FilterKind::Closure(closure) => closure,
        };
        match catch_unwind(AssertUnwindSafe(|| closure(value))) {
            Ok(keep) => keep,
            // The panic payload is intentionally discarded.
            Err(_) => false,
        }
    }
}
/// Configuration for [`MemQ::Map`].
#[derive(Clone)]
pub struct Map<T: Cacheable> {
/// Shared closure that turns one `Arc<T>` into another.
map: Arc<MapFn<T>>,
}
/// Configuration for [`MemQ::FlatMap`].
#[derive(Clone)]
pub struct FlatMap<T: Cacheable> {
/// Shared closure that turns one `Arc<T>` into a list of `Arc<T>`.
flat_map: Arc<FlatMapFn<T>>,
}
/// Configuration for [`MemQ::Take`].
#[derive(Clone)]
pub struct Take {
/// Maximum number of leading values to keep.
count: usize,
}
/// Configuration for [`MemQ::Skip`].
#[derive(Clone)]
pub struct Skip {
/// Number of leading values to drop.
count: usize,
}
/// Configuration for [`MemQ::Chain`].
#[derive(Clone)]
pub struct Chain<T: Cacheable> {
/// Values appended after the input; each application clones the `Arc` handles.
values: Vec<Arc<T>>,
}
/// Configuration for [`MemQ::Sort`].
#[derive(Clone)]
pub struct Sort<T: Cacheable> {
/// Shared comparator handed to `sort_by` when the step is applied.
compare: Arc<CompareFn<T>>,
}
/// Configuration for [`MemQ::Unique`].
#[derive(Clone)]
pub struct Unique<T: Cacheable> {
/// Shared key extractor; values whose keys compare equal are duplicates.
key: Arc<KeyFn<T>>,
}
/// Configuration for [`MemQ::GroupBy`].
#[derive(Clone)]
pub struct GroupBy<T: Cacheable> {
/// Shared key extractor; values whose keys compare equal land in the same group.
key: Arc<KeyFn<T>>,
}
/// Configuration for [`MemQ::Partition`].
#[derive(Clone)]
pub struct Partition<T: Cacheable> {
/// Shared predicate deciding which values are moved to the front.
predicate: Arc<PredicateFn<T>>,
}
/// Configuration for [`MemQ::Fold`].
#[derive(Clone)]
pub struct Fold<T: Cacheable> {
/// Shared reducer that consumes the whole list and returns at most one value.
fold: Arc<FoldFn<T>>,
}
impl<T: Cacheable> MemQ<T> {
pub fn filter<F>(predicate: F) -> Self
where
F: Fn(&T) -> bool + Send + Sync + 'static,
{
Self::Filter(Filter {
kind: FilterKind::Closure(Arc::new(predicate)),
})
}
pub fn filter_basic(predicate: BasicPredicate<T>) -> Self {
Self::Filter(Filter {
kind: FilterKind::Basic(predicate),
})
}
pub fn map_arc<F>(map: F) -> Self
where
F: Fn(Arc<T>) -> Arc<T> + Send + Sync + 'static,
{
Self::Map(Map { map: Arc::new(map) })
}
pub fn flat_map_arc<F>(flat_map: F) -> Self
where
F: Fn(Arc<T>) -> Vec<Arc<T>> + Send + Sync + 'static,
{
Self::FlatMap(FlatMap {
flat_map: Arc::new(flat_map),
})
}
pub fn take(count: usize) -> Self {
Self::Take(Take { count })
}
pub fn skip(count: usize) -> Self {
Self::Skip(Skip { count })
}
pub fn chain<I>(values: I) -> Self
where
I: IntoIterator<Item = Arc<T>>,
{
Self::Chain(Chain {
values: values.into_iter().collect(),
})
}
pub fn sort<F>(compare: F) -> Self
where
F: Fn(&Arc<T>, &Arc<T>) -> Ordering + Send + Sync + 'static,
{
Self::Sort(Sort {
compare: Arc::new(compare),
})
}
pub fn unique() -> Self {
Self::unique_by(|value: &T| value.id())
}
pub fn unique_by<K, F>(key: F) -> Self
where
K: Eq + Hash + Send + Sync + 'static,
F: Fn(&T) -> K + Send + Sync + 'static,
{
Self::Unique(Unique {
key: Arc::new(move |value| ErasedKey::new(key(value))),
})
}
pub fn group_by<K, F>(key: F) -> Self
where
K: Eq + Hash + Send + Sync + 'static,
F: Fn(&T) -> K + Send + Sync + 'static,
{
Self::GroupBy(GroupBy {
key: Arc::new(move |value| ErasedKey::new(key(value))),
})
}
pub fn partition<F>(predicate: F) -> Self
where
F: Fn(&T) -> bool + Send + Sync + 'static,
{
Self::Partition(Partition {
predicate: Arc::new(predicate),
})
}
pub fn fold<F>(fold: F) -> Self
where
F: Fn(Vec<Arc<T>>) -> Option<Arc<T>> + Send + Sync + 'static,
{
Self::Fold(Fold {
fold: Arc::new(fold),
})
}
pub fn apply(&self, values: Vec<Arc<T>>) -> Vec<Arc<T>> {
match self {
Self::Filter(op) => values
.into_iter()
.filter(|value| op.evaluate(value))
.collect(),
Self::Map(op) => values.into_iter().map(|value| (op.map)(value)).collect(),
Self::FlatMap(op) => values
.into_iter()
.flat_map(|value| (op.flat_map)(value))
.collect(),
Self::Take(op) => values.into_iter().take(op.count).collect(),
Self::Skip(op) => values.into_iter().skip(op.count).collect(),
Self::Chain(op) => values
.into_iter()
.chain(op.values.iter().cloned())
.collect(),
Self::Sort(op) => {
let mut values = values;
values.sort_by(|left, right| (op.compare)(left, right));
values
}
Self::Unique(op) => unique_by(values, &op.key),
Self::GroupBy(op) => group_by(values, &op.key),
Self::Partition(op) => partition(values, &op.predicate),
Self::Fold(op) => (op.fold)(values).into_iter().collect(),
}
}
pub fn apply_all(ops: &[Self], values: Vec<Arc<T>>) -> Vec<Arc<T>> {
ops.iter().fold(values, |values, op| op.apply(values))
}
}
// NOTE(review): none of the bodies below clone a `T`; the `Clone` bound looks
// like a deliberate API restriction — confirm before relaxing it.
impl<T> MemQ<T>
where
    T: Cacheable + Clone,
{
    /// Creates a mapping step from a closure over plain values.
    ///
    /// The closure borrows each value and returns a new one, which is wrapped
    /// in a fresh `Arc`.
    pub fn map<F>(map: F) -> Self
    where
        F: Fn(&T) -> T + Send + Sync + 'static,
    {
        Self::map_arc(move |shared: Arc<T>| {
            let mapped: T = map(&shared);
            Arc::new(mapped)
        })
    }

    /// Creates a flat-mapping step from a closure over plain values.
    ///
    /// Every value the closure returns is wrapped in its own `Arc`.
    pub fn flat_map<F>(flat_map: F) -> Self
    where
        F: Fn(&T) -> Vec<T> + Send + Sync + 'static,
    {
        Self::flat_map_arc(move |shared: Arc<T>| {
            let mut wrapped = Vec::new();
            for produced in flat_map(&shared) {
                wrapped.push(Arc::new(produced));
            }
            wrapped
        })
    }

    /// Creates a chain step from plain values, wrapping each one in an `Arc`.
    pub fn chain_values<I>(values: I) -> Self
    where
        I: IntoIterator<Item = T>,
    {
        let wrapped = values.into_iter().map(|value| Arc::new(value));
        Self::chain(wrapped)
    }
}
impl<T: Cacheable> From<BasicPredicate<T>> for MemQ<T> {
fn from(predicate: BasicPredicate<T>) -> Self {
Self::filter_basic(predicate)
}
}
impl<T: Cacheable> MemQ<T> {
    /// Creates a sort step that orders values by `key(value)`.
    ///
    /// The key is recomputed for both operands on every comparison.
    pub fn sort_by_key<K, F>(key: F) -> Self
    where
        K: Ord,
        F: Fn(&T) -> K + Send + Sync + 'static,
    {
        Self::sort(move |left, right| {
            let left_key = key(left);
            let right_key = key(right);
            left_key.cmp(&right_key)
        })
    }
}
/// Keeps the first value seen for each distinct key and drops later values
/// whose key compares equal. The relative order of kept values is unchanged.
fn unique_by<T: Cacheable>(values: Vec<Arc<T>>, key: &Arc<KeyFn<T>>) -> Vec<Arc<T>> {
    let mut seen = HashSet::new();
    let mut kept = Vec::new();
    for candidate in values {
        // `insert` returns `true` only the first time a key is recorded.
        if seen.insert(key(&candidate)) {
            kept.push(candidate);
        }
    }
    kept
}
/// Reorders `values` so that values sharing a key are adjacent.
///
/// Groups appear in the order their key was first seen, and values keep their
/// original relative order inside each group. No value is added or removed.
fn group_by<T: Cacheable>(values: Vec<Arc<T>>, key: &Arc<KeyFn<T>>) -> Vec<Arc<T>> {
    let total = values.len();
    let mut index_by_key: HashMap<ErasedKey, usize> = HashMap::new();
    let mut groups: Vec<Vec<Arc<T>>> = Vec::new();
    for value in values {
        // The entry API hashes the key once, whether or not the group exists yet.
        let index = *index_by_key.entry(key(&value)).or_insert_with(|| {
            groups.push(Vec::new());
            groups.len() - 1
        });
        groups[index].push(value);
    }
    // The output holds exactly the input values, so reserve once up front.
    let mut grouped = Vec::with_capacity(total);
    for group in groups {
        grouped.extend(group);
    }
    grouped
}
/// Moves every value accepted by `predicate` ahead of the rejected ones,
/// keeping the original relative order inside both halves.
///
/// The predicate runs inside `catch_unwind`; a panicking call counts as a
/// rejection.
fn partition<T: Cacheable>(values: Vec<Arc<T>>, predicate: &Arc<PredicateFn<T>>) -> Vec<Arc<T>> {
    let (mut ordered, rejected): (Vec<Arc<T>>, Vec<Arc<T>>) =
        values.into_iter().partition(|candidate| {
            // Only a clean `true` accepts; `false` and a caught panic both reject.
            matches!(
                catch_unwind(AssertUnwindSafe(|| predicate(candidate))),
                Ok(true)
            )
        });
    ordered.extend(rejected);
    ordered
}
/// A hashable, comparable key whose concrete type has been erased.
///
/// Two keys are equal only when they were built from the same concrete type
/// and the underlying values compare equal with that type's `Eq`.
struct ErasedKey {
    /// `TypeId` of the concrete key type.
    type_id: TypeId,
    /// Hash of the concrete key, computed once at construction.
    hash: u64,
    /// The boxed concrete key, kept for the final equality check.
    value: Box<dyn Any + Send + Sync>,
    /// Comparison routine instantiated for the concrete key type.
    eq: fn(&(dyn Any + Send + Sync), &(dyn Any + Send + Sync)) -> bool,
}

impl ErasedKey {
    /// Wraps `value`, recording its type, its hash (via a fresh
    /// `DefaultHasher`) and the equality routine for `K`.
    fn new<K>(value: K) -> Self
    where
        K: Eq + Hash + Send + Sync + 'static,
    {
        let hash = {
            let mut hasher = DefaultHasher::new();
            value.hash(&mut hasher);
            hasher.finish()
        };
        ErasedKey {
            type_id: TypeId::of::<K>(),
            hash,
            value: Box::new(value),
            eq: erased_eq::<K>,
        }
    }
}

impl PartialEq for ErasedKey {
    /// Cheap checks first (type, then cached hash); the boxed values are only
    /// compared when both agree.
    fn eq(&self, other: &Self) -> bool {
        if self.type_id != other.type_id || self.hash != other.hash {
            return false;
        }
        (self.eq)(&*self.value, &*other.value)
    }
}

impl Eq for ErasedKey {}

impl Hash for ErasedKey {
    /// Feeds the type id followed by the cached hash; the boxed value is not
    /// hashed again.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let ErasedKey {
            type_id,
            hash: cached,
            ..
        } = self;
        type_id.hash(state);
        cached.hash(state);
    }
}

/// Compares two erased values as `K`; returns `false` when either one is not
/// actually a `K`.
fn erased_eq<K>(left: &(dyn Any + Send + Sync), right: &(dyn Any + Send + Sync)) -> bool
where
    K: Eq + 'static,
{
    if let (Some(lhs), Some(rhs)) = (left.downcast_ref::<K>(), right.downcast_ref::<K>()) {
        lhs == rhs
    } else {
        false
    }
}