use std::marker::PhantomData;
use std::ops::Mul;
use num_traits::PrimInt;
use rayon::prelude::*;
use smol_str::SmolStr;
/// Owned, fixed-length sequence of [`IndexKey`]s that together form one tuple index.
pub type IndexTuple = Box<[IndexKey]>;
/// Internal element storage for a [`Set`]; one variant per kind of element.
#[derive(Clone, Debug)]
enum SetRepr {
/// Integer elements stored as `i64`. Used both for dense ranges and for
/// arbitrary integer lists (see `Set::from_ints`).
Range(Vec<i64>),
/// String elements.
Strings(Vec<SmolStr>),
/// Tuple elements, one boxed slice of keys per element.
Tuples(Vec<IndexTuple>),
}
impl SetRepr {
fn len(&self) -> usize {
match self {
Self::Range(v) => v.len(),
Self::Strings(v) => v.len(),
Self::Tuples(v) => v.len(),
}
}
}
/// One integer axis, described by its first value and its element count.
///
/// `Set::dense_i64` records a single `Axis` for the half-open range it
/// materializes; `Set::product` concatenates the axes of both operands when
/// each of them carries axes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Axis {
/// First value on the axis.
pub start: i64,
/// Number of values on the axis.
pub len: usize,
}
/// A sequence of index keys, tagged at the type level with the key type `K`.
///
/// `K` is never stored. It only selects which typed helpers apply (for example
/// `filter_typed` requires `K: FromIndexKey`, and `product` requires `K: KeyCat<B>`).
pub struct Set<K = IndexKey> {
/// Element storage.
repr: SetRepr,
/// Axis metadata. `None` unless the set was built by `dense_i64` or as the
/// product of two sets that both carry axes.
axes: Option<Box<[Axis]>>,
/// Type-level marker only.
/// NOTE(review): `fn() -> K` was presumably chosen so `Set<K>` stays
/// `Send`/`Sync` regardless of `K` — confirm intent.
_k: PhantomData<fn() -> K>,
}
impl<K> Clone for Set<K> {
fn clone(&self) -> Self {
Self { repr: self.repr.clone(), axes: self.axes.clone(), _k: PhantomData }
}
}
impl<K> std::fmt::Debug for Set<K> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Debug::fmt(&self.repr, f)
}
}
impl<K> Set<K> {
fn from_repr(repr: SetRepr) -> Self {
Self { repr, axes: None, _k: PhantomData }
}
fn from_repr_with_axes(repr: SetRepr, axes: Box<[Axis]>) -> Self {
Self { repr, axes: Some(axes), _k: PhantomData }
}
pub(crate) fn axes(&self) -> Option<&[Axis]> {
self.axes.as_deref()
}
pub fn tuples<I, T>(iter: I) -> Self
where
I: IntoIterator<Item = T>,
T: Into<IndexTuple>,
{
Self::from_repr(SetRepr::Tuples(iter.into_iter().map(Into::into).collect()))
}
#[must_use]
pub fn filter<F>(&self, mut f: F) -> Self
where
F: FnMut(&IndexKey) -> bool,
{
let repr = match &self.repr {
SetRepr::Range(v) => {
SetRepr::Range(v.iter().copied().filter(|i| f(&IndexKey::Int(*i))).collect())
}
SetRepr::Strings(v) => SetRepr::Strings(
v.iter()
.filter_map(|s| {
let key = IndexKey::Str(s.clone());
f(&key).then(|| s.clone())
})
.collect(),
),
SetRepr::Tuples(v) => SetRepr::Tuples(
v.iter()
.filter_map(|t| {
let key = IndexKey::Tuple(t.clone());
f(&key).then(|| match key {
IndexKey::Tuple(owned) => owned,
_ => unreachable!(),
})
})
.collect(),
),
};
Self::from_repr(repr)
}
#[must_use]
pub fn filter_typed<F>(&self, mut pred: F) -> Self
where
K: FromIndexKey,
F: FnMut(K) -> bool,
{
self.filter(|k| pred(K::from_index_key(k)))
}
pub fn len(&self) -> usize {
self.repr.len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn is_range(&self) -> bool {
matches!(self.repr, SetRepr::Range(_))
}
pub fn is_strings(&self) -> bool {
matches!(self.repr, SetRepr::Strings(_))
}
pub fn is_tuples(&self) -> bool {
matches!(self.repr, SetRepr::Tuples(_))
}
#[must_use]
pub fn product<B>(a: &Set<K>, b: &Set<B>) -> Set<<K as KeyCat<B>>::Out>
where
K: KeyCat<B>,
{
let a_len = a.len();
let b_len = b.len();
let total = a_len.checked_mul(b_len).expect("Set::product size overflow");
let axes = match (a.axes(), b.axes()) {
(Some(aa), Some(bb)) => {
let mut v = Vec::with_capacity(aa.len() + bb.len());
v.extend_from_slice(aa);
v.extend_from_slice(bb);
Some(v.into_boxed_slice())
}
_ => None,
};
const PAR_THRESHOLD: usize = 4096;
let out: Vec<IndexTuple> = if total < PAR_THRESHOLD {
let mut out = Vec::with_capacity(total);
for ka in a {
for kb in b {
let mut parts: Vec<IndexKey> = Vec::new();
push_flat(&mut parts, ka.clone());
push_flat(&mut parts, kb);
out.push(parts.into_boxed_slice());
}
}
out
} else {
let a_keys: Vec<IndexKey> = a.iter().collect();
let b_keys: Vec<IndexKey> = b.iter().collect();
(0..total)
.into_par_iter()
.map(|i| {
let mut parts: Vec<IndexKey> = Vec::new();
push_flat(&mut parts, a_keys[i / b_len].clone());
push_flat(&mut parts, b_keys[i % b_len].clone());
parts.into_boxed_slice()
})
.collect()
};
match axes {
Some(axes) => Set::from_repr_with_axes(SetRepr::Tuples(out), axes),
None => Set::from_repr(SetRepr::Tuples(out)),
}
}
}
impl Set<usize> {
    /// Builds a dense integer set covering the half-open range `r`.
    ///
    /// # Panics
    ///
    /// Panics if either bound of `r` cannot be represented as an `i64`.
    #[must_use]
    pub fn range<T: PrimInt>(r: std::ops::Range<T>) -> Self {
        let std::ops::Range { start, end } = r;
        let lo = start.to_i64().expect("range start out of i64 range");
        let hi = end.to_i64().expect("range end out of i64 range");
        Self::dense_i64(lo, hi)
    }

    /// Materializes `start..end` and records a single [`Axis`] describing it.
    ///
    /// When `start >= end` the set is empty and the axis has length 0.
    pub(crate) fn dense_i64(start: i64, end: i64) -> Self {
        let vals: Vec<i64> = (start..end).collect();
        let axis = Axis { start, len: vals.len() };
        Self::from_repr_with_axes(SetRepr::Range(vals), vec![axis].into_boxed_slice())
    }

    /// Builds an integer set from arbitrary values, kept in iteration order.
    ///
    /// No axis metadata is attached, since the values need not be contiguous.
    ///
    /// # Panics
    ///
    /// Panics if any element cannot be represented as an `i64`.
    pub fn from_ints<T, I>(iter: I) -> Self
    where
        T: PrimInt,
        I: IntoIterator<Item = T>,
    {
        let vals: Vec<i64> = iter
            .into_iter()
            .map(|item| item.to_i64().expect("element out of i64 range"))
            .collect();
        Self::from_repr(SetRepr::Range(vals))
    }
}
impl Set<String> {
    /// Builds a string-backed set from anything convertible to `SmolStr`,
    /// kept in iteration order. No axis metadata is attached.
    pub fn strings<I, S>(iter: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<SmolStr>,
    {
        let names: Vec<SmolStr> = iter.into_iter().map(|s| s.into()).collect();
        Self::from_repr(SetRepr::Strings(names))
    }
}
/// Appends `k` to `dst`, splicing in the members of a tuple key one level deep
/// instead of pushing the tuple itself.
fn push_flat(dst: &mut Vec<IndexKey>, k: IndexKey) {
    match k {
        IndexKey::Tuple(members) => dst.append(&mut members.into_vec()),
        scalar @ (IndexKey::Int(_) | IndexKey::Str(_)) => dst.push(scalar),
    }
}
/// Collects `items` into a single boxed tuple, flattening any tuple keys among
/// them via [`push_flat`].
fn make_tuple<I: IntoIterator<Item = IndexKey>>(items: I) -> IndexTuple {
    items
        .into_iter()
        .fold(Vec::<IndexKey>::new(), |mut flat, key| {
            push_flat(&mut flat, key);
            flat
        })
        .into_boxed_slice()
}
/// `&a * &b` is shorthand for `Set::product(&a, &b)`.
impl<A, B> Mul<&Set<B>> for &Set<A>
where
    A: KeyCat<B>,
{
    type Output = Set<<A as KeyCat<B>>::Out>;

    fn mul(self, rhs: &Set<B>) -> Self::Output {
        Set::<A>::product(self, rhs)
    }
}
/// Type-level concatenation of index key types: `Self` followed by `Rhs` gives `Out`.
///
/// Used as the bound on `Set::product` (and on `&Set<A> * &Set<B>`) to compute
/// the key type of the resulting set. It has no methods; it exists only to
/// drive type inference and the custom diagnostic below.
#[diagnostic::on_unimplemented(
message = "cannot form a Cartesian product index key from `{Self}` and `{Rhs}`",
label = "no product key for `{Self}` * `{Rhs}`",
note = "`&a * &b` composes scalar keys (`usize`/`i64`/`i32`/`String`) into flat tuples up to arity 4. A 5th axis or a non-scalar operand is unsupported"
)]
pub trait KeyCat<Rhs> {
/// Key type of the product set.
type Out;
}
/// Marker for key types that occupy exactly one tuple position when combined
/// through [`KeyCat`].
pub trait ScalarKey {}
impl ScalarKey for usize {}
impl ScalarKey for i32 {}
impl ScalarKey for i64 {}
impl ScalarKey for String {}
// The dynamically typed key is also treated as a single position.
impl ScalarKey for IndexKey {}
// scalar * scalar -> pair
impl<A: ScalarKey, B: ScalarKey> KeyCat<B> for A {
type Out = (A, B);
}
// pair * scalar -> triple
impl<A, B, C: ScalarKey> KeyCat<C> for (A, B) {
type Out = (A, B, C);
}
// triple * scalar -> quadruple
impl<A, B, C, D: ScalarKey> KeyCat<D> for (A, B, C) {
type Out = (A, B, C, D);
}
// scalar * pair -> triple
impl<A: ScalarKey, B, C> KeyCat<(B, C)> for A {
type Out = (A, B, C);
}
// scalar * triple -> quadruple
impl<A: ScalarKey, B, C, D> KeyCat<(B, C, D)> for A {
type Out = (A, B, C, D);
}
// pair * pair -> quadruple
impl<A, B, C, D> KeyCat<(C, D)> for (A, B) {
type Out = (A, B, C, D);
}
/// A dynamically typed index key: an integer, a string, or a tuple of keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum IndexKey {
/// Integer key.
Int(i64),
/// String key.
Str(SmolStr),
/// Tuple key made of other keys.
Tuple(IndexTuple),
}
impl IndexKey {
    /// Builds a tuple key from `iter`, converting each item into an
    /// [`IndexKey`] and flattening any items that are themselves tuple keys.
    pub fn tuple<I, T>(iter: I) -> Self
    where
        I: IntoIterator<Item = T>,
        T: Into<IndexKey>,
    {
        let keys = iter.into_iter().map(Into::<IndexKey>::into);
        Self::Tuple(make_tuple(keys))
    }

    /// The integer value, or `None` if this is not an `Int` key.
    pub fn as_i64(&self) -> Option<i64> {
        match self {
            Self::Int(v) => Some(*v),
            Self::Str(_) | Self::Tuple(_) => None,
        }
    }

    /// The string value, or `None` if this is not a `Str` key.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            Self::Str(s) => Some(s.as_str()),
            Self::Int(_) | Self::Tuple(_) => None,
        }
    }

    /// The tuple members, or `None` if this is not a `Tuple` key.
    pub fn as_tuple(&self) -> Option<&[IndexKey]> {
        match self {
            Self::Tuple(t) => Some(&**t),
            Self::Int(_) | Self::Str(_) => None,
        }
    }
}
impl From<i64> for IndexKey {
fn from(v: i64) -> Self {
Self::Int(v)
}
}
impl From<i32> for IndexKey {
fn from(v: i32) -> Self {
Self::Int(i64::from(v))
}
}
impl From<usize> for IndexKey {
fn from(v: usize) -> Self {
Self::Int(i64::try_from(v).expect("usize -> i64 overflow"))
}
}
impl From<&str> for IndexKey {
fn from(s: &str) -> Self {
Self::Str(SmolStr::new(s))
}
}
impl From<String> for IndexKey {
fn from(s: String) -> Self {
Self::Str(SmolStr::from(s))
}
}
impl From<&String> for IndexKey {
fn from(s: &String) -> Self {
Self::Str(SmolStr::new(s.as_str()))
}
}
impl From<&usize> for IndexKey {
fn from(v: &usize) -> Self {
Self::from(*v)
}
}
impl From<&i64> for IndexKey {
fn from(v: &i64) -> Self {
Self::Int(*v)
}
}
impl From<&i32> for IndexKey {
fn from(v: &i32) -> Self {
Self::Int(i64::from(*v))
}
}
impl From<&&str> for IndexKey {
fn from(s: &&str) -> Self {
Self::Str(SmolStr::new(*s))
}
}
impl From<&&String> for IndexKey {
fn from(s: &&String) -> Self {
Self::Str(SmolStr::new(s.as_str()))
}
}
// Conversions from native tuples of arity 2 to 4. Each member is converted to
// an `IndexKey` and the results are combined with `make_tuple`, so members
// that are themselves tuple keys are flattened into the result.

impl<A, B> From<(A, B)> for IndexKey
where
    A: Into<IndexKey>,
    B: Into<IndexKey>,
{
    fn from(t: (A, B)) -> Self {
        let (a, b) = t;
        Self::Tuple(make_tuple([a.into(), b.into()]))
    }
}

impl<A, B, C> From<(A, B, C)> for IndexKey
where
    A: Into<IndexKey>,
    B: Into<IndexKey>,
    C: Into<IndexKey>,
{
    fn from(t: (A, B, C)) -> Self {
        let (a, b, c) = t;
        Self::Tuple(make_tuple([a.into(), b.into(), c.into()]))
    }
}

impl<A, B, C, D> From<(A, B, C, D)> for IndexKey
where
    A: Into<IndexKey>,
    B: Into<IndexKey>,
    C: Into<IndexKey>,
    D: Into<IndexKey>,
{
    fn from(t: (A, B, C, D)) -> Self {
        let (a, b, c, d) = t;
        Self::Tuple(make_tuple([a.into(), b.into(), c.into(), d.into()]))
    }
}
/// Conversion from a dynamically typed [`IndexKey`] into a concrete key type.
///
/// The conversion is infallible at the type level. The implementations in this
/// file panic when the key's variant, numeric range, or tuple arity does not
/// match the requested type.
#[diagnostic::on_unimplemented(
message = "`{Self}` is not a valid index key type",
label = "cannot be decoded from an index key",
note = "index keys decode to `usize`, `i64`, `i32`, `String`, `IndexKey`, or a tuple of those up to arity 4",
note = "annotate the binding to one of these (e.g. `for k: usize in set`) or match the `Set`'s key type"
)]
pub trait FromIndexKey: Sized {
/// Decodes `k` into `Self`.
fn from_index_key(k: &IndexKey) -> Self;
}
impl FromIndexKey for IndexKey {
fn from_index_key(k: &IndexKey) -> Self {
k.clone()
}
}
impl FromIndexKey for i64 {
fn from_index_key(k: &IndexKey) -> Self {
k.as_i64().unwrap_or_else(|| panic!("expected Int key, got {k:?}"))
}
}
impl FromIndexKey for i32 {
fn from_index_key(k: &IndexKey) -> Self {
let v = i64::from_index_key(k);
i32::try_from(v).unwrap_or_else(|_| panic!("key {v} out of i32 range"))
}
}
impl FromIndexKey for usize {
fn from_index_key(k: &IndexKey) -> Self {
let v = i64::from_index_key(k);
usize::try_from(v).unwrap_or_else(|_| panic!("key {v} out of usize range"))
}
}
impl FromIndexKey for String {
fn from_index_key(k: &IndexKey) -> Self {
k.as_str().unwrap_or_else(|| panic!("expected Str key, got {k:?}")).to_owned()
}
}
fn tuple_parts<'a>(k: &'a IndexKey, expected: usize) -> &'a [IndexKey] {
let p = k.as_tuple().unwrap_or_else(|| panic!("expected Tuple key, got {k:?}"));
assert_eq!(p.len(), expected, "expected tuple of arity {expected}, got arity {}", p.len());
p
}
// Decoding into native tuples of arity 2 to 4. `tuple_parts` panics unless the
// key is a tuple of exactly the requested arity, so the slice patterns below
// always match and the `else` arms are never taken.

impl<A, B> FromIndexKey for (A, B)
where
    A: FromIndexKey,
    B: FromIndexKey,
{
    fn from_index_key(k: &IndexKey) -> Self {
        let [a, b] = tuple_parts(k, 2) else {
            unreachable!()
        };
        (A::from_index_key(a), B::from_index_key(b))
    }
}

impl<A, B, C> FromIndexKey for (A, B, C)
where
    A: FromIndexKey,
    B: FromIndexKey,
    C: FromIndexKey,
{
    fn from_index_key(k: &IndexKey) -> Self {
        let [a, b, c] = tuple_parts(k, 3) else {
            unreachable!()
        };
        (A::from_index_key(a), B::from_index_key(b), C::from_index_key(c))
    }
}

impl<A, B, C, D> FromIndexKey for (A, B, C, D)
where
    A: FromIndexKey,
    B: FromIndexKey,
    C: FromIndexKey,
    D: FromIndexKey,
{
    fn from_index_key(k: &IndexKey) -> Self {
        let [a, b, c, d] = tuple_parts(k, 4) else {
            unreachable!()
        };
        (
            A::from_index_key(a),
            B::from_index_key(b),
            C::from_index_key(c),
            D::from_index_key(d),
        )
    }
}
impl<'a, K> IntoIterator for &'a Set<K> {
type Item = IndexKey;
type IntoIter = SetIter<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<K> Set<K> {
    /// Sequential iterator over the elements in storage order, yielding each
    /// as an owned [`IndexKey`].
    pub fn iter(&self) -> SetIter<'_> {
        SetIter { repr: &self.repr, pos: 0 }
    }

    /// Parallel iterator over the elements, yielding each as an owned [`IndexKey`].
    ///
    /// Keys are produced lazily, one per index, directly from the borrowed
    /// storage. No intermediate `Vec` of cloned keys is built, so only the
    /// keys actually consumed are cloned.
    pub fn par_iter(&self) -> impl ParallelIterator<Item = IndexKey> + '_ {
        let repr = &self.repr;
        // Every index in `0..repr.len()` is in bounds for whichever vector
        // backs `repr`, so the indexing below cannot panic.
        (0..repr.len()).into_par_iter().map(move |i| match repr {
            SetRepr::Range(v) => IndexKey::Int(v[i]),
            SetRepr::Strings(v) => IndexKey::Str(v[i].clone()),
            SetRepr::Tuples(v) => IndexKey::Tuple(v[i].clone()),
        })
    }
}
/// Borrowing iterator over a [`Set`]'s elements, yielding each as an owned [`IndexKey`].
#[derive(Debug)]
pub struct SetIter<'a> {
/// Storage being walked.
repr: &'a SetRepr,
/// Index of the next element to yield.
pos: usize,
}
impl<'a> Iterator for SetIter<'a> {
    type Item = IndexKey;

    /// Yields the element at the current position as an owned key and
    /// advances, or returns `None` once the storage is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        // `?` returns `None` without advancing when `pos` is past the end, so
        // the iterator keeps returning `None` after exhaustion.
        let item = match self.repr {
            SetRepr::Range(v) => IndexKey::Int(*v.get(self.pos)?),
            SetRepr::Strings(v) => IndexKey::Str(v.get(self.pos)?.clone()),
            SetRepr::Tuples(v) => IndexKey::Tuple(v.get(self.pos)?.clone()),
        };
        self.pos += 1;
        Some(item)
    }

    /// Exact number of remaining elements, so `collect()` and `extend()` can
    /// allocate once instead of growing repeatedly.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `pos` never exceeds `len` (it only advances after a successful
        // `get`); `saturating_sub` merely keeps this total.
        let remaining = self.repr.len().saturating_sub(self.pos);
        (remaining, Some(remaining))
    }
}