use rbe::{Context, Key, RbeTable, Ref, Value};
use std::collections::{HashMap, HashSet};
/// A complete partitioning: a list of [`Partition`] entries, one per key of the expression map
/// that `partitions_iter` was given.
pub type Partitions<T, K, V, R, Ctx> = Vec<Partition<T, K, V, R, Ctx>>;
/// One entry of a partitioning, as assembled by `partitions_iter`: the key of type `T`, the
/// `RbeTable`s stored under that key, and the `(K, V, Ctx)` triples assigned to that key.
pub type Partition<T, K, V, R, Ctx> = (T, Vec<RbeTable<K, V, R, Ctx>>, Vec<(K, V, Ctx)>);
/// Iterator over the ways of distributing a list of items into `k` ordered buckets such that
/// bucket `i` satisfies predicate `i`.
///
/// `k` is the number of predicates supplied to [`KPartitionIteratorMultiPredicate::new`].
/// Buckets may be empty. Every assignment of the retained items to buckets is examined
/// (`k^n` assignments for `n` retained items); only the assignments in which every bucket is
/// accepted by its predicate are yielded.
///
/// Each yielded value is a `Vec` of `k` buckets, and items inside a bucket keep their original
/// relative order.
pub struct KPartitionIteratorMultiPredicate<T, F> {
    /// Items to distribute, after the filtering performed by `new`.
    items: Vec<T>,
    /// Number of buckets (equal to `predicates.len()`).
    k: usize,
    /// Bucket index chosen for each item, used as a base-`k` counter whose least significant
    /// digit belongs to the first item. `None` once every assignment has been visited.
    current: Option<Vec<usize>>,
    /// `predicates[i]` decides whether the contents of bucket `i` are acceptable.
    predicates: Vec<F>,
}

impl<T: Clone, F> KPartitionIteratorMultiPredicate<T, F>
where
    F: Fn(&Vec<T>) -> bool,
{
    /// Creates the iterator, positioned on the assignment that puts every item in bucket 0.
    ///
    /// An item is kept only if at least one predicate accepts the one-element bucket containing
    /// just that item; items rejected by every predicate are dropped and never appear in any
    /// yielded partition. With no predicates at all, every item is dropped and the iterator
    /// yields a single empty list of buckets.
    pub fn new(items: Vec<T>, predicates: Vec<F>) -> Self {
        let k = predicates.len();
        let items: Vec<T> = items
            .into_iter()
            .filter(|item| {
                // Build the probe bucket once per item instead of once per predicate tried.
                let singleton = vec![item.clone()];
                predicates.iter().any(|accepts| accepts(&singleton))
            })
            .collect();
        let current = Some(vec![0; items.len()]);
        Self {
            items,
            k,
            current,
            predicates,
        }
    }
}

impl<T: Clone, F> Iterator for KPartitionIteratorMultiPredicate<T, F>
where
    F: Fn(&Vec<T>) -> bool,
{
    type Item = Vec<Vec<T>>;

    /// Returns the next assignment whose buckets all satisfy their predicates, or `None` once
    /// every assignment has been examined.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // Take ownership of the pending assignment; `None` means the counter wrapped around
            // on a previous call and the enumeration is finished.
            let mut assignment = self.current.take()?;

            // Materialise the buckets described by the assignment.
            let mut partitions = vec![Vec::new(); self.k];
            for (item, &bucket) in self.items.iter().zip(&assignment) {
                partitions[bucket].push(item.clone());
            }

            // Advance the base-`k` counter in place. If the carry runs off the last digit
            // (or there are no digits at all) `self.current` stays `None`.
            let mut carry = true;
            for digit in assignment.iter_mut() {
                *digit += 1;
                if *digit < self.k {
                    carry = false;
                    break;
                }
                *digit = 0;
            }
            if !carry {
                self.current = Some(assignment);
            }

            let all_valid = partitions
                .iter()
                .zip(self.predicates.iter())
                .all(|(subset, predicate)| predicate(subset));
            if all_valid {
                return Some(partitions);
            }
        }
    }
}
/// Enumerates the ways of distributing `neighs` among the entries of `exprs`.
///
/// One bucket predicate is built per entry of `exprs` (see `build_conditions`), the triples in
/// `neighs` are copied into a `KPartitionIteratorMultiPredicate`, and every partition it yields
/// is turned into a [`Partitions`] value by attaching to each bucket a clone of the key and of
/// the RBE tables of the corresponding map entry.
///
/// Buckets are paired with map entries by position: the predicates come from iterating the
/// map's values and the labels from iterating the same, unmodified map again.
pub fn partitions_iter<'a, T, K, V, R, Ctx>(
    neighs: &'a [(K, V, Ctx)],
    exprs: &'a HashMap<T, Vec<RbeTable<K, V, R, Ctx>>>,
) -> impl Iterator<Item = Partitions<T, K, V, R, Ctx>> + 'a
where
    K: Key,
    V: Value,
    R: Ref,
    Ctx: Context,
    T: std::hash::Hash + Eq + Clone,
{
    let bucket_predicates: Vec<_> = build_conditions(exprs).collect();
    KPartitionIteratorMultiPredicate::new(neighs.to_vec(), bucket_predicates).map(move |buckets| {
        let mut labelled = Vec::with_capacity(buckets.len());
        for ((key, rbes), subset) in exprs.iter().zip(buckets) {
            labelled.push((key.clone(), rbes.clone(), subset));
        }
        labelled
    })
}
/// Builds one bucket predicate per entry of `triple_exprs`, in the order of `HashMap::values`.
///
/// For each entry, the keys reported by `RbeTable::keys` for all of its tables are gathered into
/// a set. The resulting predicate accepts a subset of triples when the key (first component) of
/// every triple belongs to that set; an empty subset is therefore always accepted.
fn build_conditions<'a, T, K, V, R, Ctx>(
    triple_exprs: &'a HashMap<T, Vec<RbeTable<K, V, R, Ctx>>>,
) -> impl Iterator<Item = impl Fn(&Vec<(K, V, Ctx)>) -> bool> + 'a
where
    K: Key,
    V: Value,
    R: Ref,
    Ctx: Context,
    T: std::hash::Hash + Eq + Clone,
{
    triple_exprs.values().map(|rbes| {
        // Keep the keys in the set itself so each membership test is a hash lookup rather than
        // a linear scan over a flattened list.
        let allowed: HashSet<K> = rbes.iter().flat_map(|rbe| rbe.keys().cloned()).collect();
        move |subset: &Vec<(K, V, Ctx)>| subset.iter().all(|(key, _, _)| allowed.contains(key))
    })
}
#[cfg(test)]
mod tests {
    use crate::KPartitionIteratorMultiPredicate;
    use std::collections::HashMap;

    /// Fn-pointer predicate over buckets of integers, shared by the numeric tests.
    type IntPredicate = fn(&Vec<i32>) -> bool;

    /// Builds one predicate per map entry; each accepts a bucket whose labels all belong to the
    /// entry's list of allowed labels.
    fn build_predicates<'a>(
        triple_exprs: &'a HashMap<char, Vec<char>>,
    ) -> impl Iterator<Item = impl Fn(&Vec<(char, i32)>) -> bool> + 'a {
        triple_exprs.values().map(|allowed| {
            let allowed = allowed.to_vec();
            move |subset: &Vec<(char, i32)>| subset.iter().all(|(label, _)| allowed.contains(label))
        })
    }

    /// Predicate that accepts every bucket.
    fn always_true<T>(_: &Vec<T>) -> bool {
        true
    }

    /// `buckets` copies of the accept-everything predicate.
    fn accept_all(buckets: usize) -> Vec<IntPredicate> {
        let accept: IntPredicate = always_true;
        vec![accept; buckets]
    }

    // Each 'P' item fits entries A or B and each 'Q' item fits A or C: 2^4 valid partitions.
    #[tracing_test::traced_test]
    #[test]
    fn test_k_partitions_preds() {
        let data: Vec<(char, i32)> = vec![('P', 1), ('P', 2), ('Q', 1), ('Q', 2)];
        let triple_exprs = HashMap::from([('A', vec!['P', 'Q']), ('B', vec!['P']), ('C', vec!['Q'])]);
        let predicates: Vec<_> = build_predicates(&triple_exprs).collect();
        let count = KPartitionIteratorMultiPredicate::new(data, predicates)
            .enumerate()
            .inspect(|(i, partition)| println!("{}: {:?}", i, partition))
            .count();
        assert_eq!(count, 16);
    }

    // The only item matches no entry, so it is dropped and a single all-empty partition remains.
    #[tracing_test::traced_test]
    #[test]
    fn test_k_partitions_preds_empty() {
        let data: Vec<(char, i32)> = vec![('R', 1)];
        let triple_exprs = HashMap::from([('A', vec!['P', 'Q']), ('B', vec!['P']), ('C', vec!['Q'])]);
        let predicates: Vec<_> = build_predicates(&triple_exprs).collect();
        let count = KPartitionIteratorMultiPredicate::new(data, predicates)
            .enumerate()
            .inspect(|(i, partition)| println!("{}: {:?}", i, partition))
            .count();
        assert_eq!(count, 1);
    }

    #[test]
    fn test_k2_n2_no_filter_count() {
        let total = KPartitionIteratorMultiPredicate::new(vec![1, 2], accept_all(2)).count();
        assert_eq!(total, 4);
    }

    #[test]
    fn test_k2_n3_no_filter_count() {
        let total = KPartitionIteratorMultiPredicate::new(vec![1, 2, 3], accept_all(2)).count();
        assert_eq!(total, 8);
    }

    #[test]
    fn test_k3_n2_no_filter_count() {
        let total = KPartitionIteratorMultiPredicate::new(vec![1, 2], accept_all(3)).count();
        assert_eq!(total, 9);
    }

    #[test]
    fn test_k1_single_partition() {
        let data = vec![1, 2, 3];
        let partitions: Vec<_> =
            KPartitionIteratorMultiPredicate::new(data.clone(), accept_all(1)).collect();
        assert_eq!(partitions.len(), 1);
        assert_eq!(partitions[0], vec![data]);
    }

    #[test]
    fn test_single_item_exact_partitions() {
        let partitions: Vec<_> =
            KPartitionIteratorMultiPredicate::new(vec![42], accept_all(2)).collect();
        assert_eq!(partitions.len(), 2);
        assert_eq!(partitions[0], vec![vec![42], vec![]]);
        assert_eq!(partitions[1], vec![vec![], vec![42]]);
    }

    #[test]
    fn test_predicate_filters_items() {
        let is_even: IntPredicate = |v| v.iter().all(|x| x % 2 == 0);
        let is_odd: IntPredicate = |v| v.iter().all(|x| x % 2 != 0);
        let partitions: Vec<_> =
            KPartitionIteratorMultiPredicate::new(vec![1, 2, 3, 4], vec![is_even, is_odd]).collect();
        assert_eq!(partitions.len(), 1);
        let evens = &partitions[0][0];
        let odds = &partitions[0][1];
        assert!(
            evens.iter().all(|x| x % 2 == 0),
            "group 0 should contain only evens"
        );
        assert!(
            odds.iter().all(|x| x % 2 != 0),
            "group 1 should contain only odds"
        );
    }

    #[test]
    fn test_reject_nonempty_forces_single_bucket() {
        let data = vec![1, 2, 3];
        let accept_anything: IntPredicate = always_true;
        let reject_nonempty: IntPredicate = |v| v.is_empty();
        let partitions: Vec<_> =
            KPartitionIteratorMultiPredicate::new(data.clone(), vec![accept_anything, reject_nonempty])
                .collect();
        assert_eq!(partitions.len(), 1);
        let only = &partitions[0];
        assert_eq!(only[0], data);
        assert!(only[1].is_empty());
    }
}