1use std::marker::PhantomData;
2use std::ops::Mul;
3
4use num_traits::PrimInt;
5use rayon::prelude::*;
6use smol_str::SmolStr;
7
/// An owned, fixed-length sequence of `IndexKey`s.
///
/// This is the payload of `IndexKey::Tuple` and the per-element storage of
/// tuple-backed sets.
pub type IndexTuple = Box<[IndexKey]>;
17
/// Internal storage of a `Set`: one variant per kind of key it can hold.
#[derive(Clone, Debug)]
enum SetRepr {
    /// Integer keys, stored explicitly. The values are not necessarily
    /// contiguous: `Set::from_ints` and `Set::filter` also build this variant.
    Range(Vec<i64>),
    /// String keys.
    Strings(Vec<SmolStr>),
    /// Tuple keys, one boxed slice per element.
    Tuples(Vec<IndexTuple>),
}
27
28impl SetRepr {
29 fn len(&self) -> usize {
30 match self {
31 Self::Range(v) => v.len(),
32 Self::Strings(v) => v.len(),
33 Self::Tuples(v) => v.len(),
34 }
35 }
36}
37
/// Description of one dense integer axis of a set.
///
/// `Set::dense_i64` records one `Axis` for the half-open range it expands, and
/// `Set::product` concatenates the axes of both operands.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Axis {
    /// First value on the axis.
    pub start: i64,
    /// Number of values on the axis.
    pub len: usize,
}
49
/// An ordered collection of index keys, tagged at the type level with the key
/// type `K` that its elements decode to (defaults to `IndexKey`).
pub struct Set<K = IndexKey> {
    /// The stored keys.
    repr: SetRepr,
    /// Dense-axis metadata; `None` when the set was built without it
    /// (e.g. by `filter`, `tuples`, `from_ints` or `strings`).
    axes: Option<Box<[Axis]>>,
    /// Type-level marker only: no value of type `K` is ever stored.
    _k: PhantomData<fn() -> K>,
}
69
70impl<K> Clone for Set<K> {
72 fn clone(&self) -> Self {
73 Self { repr: self.repr.clone(), axes: self.axes.clone(), _k: PhantomData }
74 }
75}
76
77impl<K> std::fmt::Debug for Set<K> {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 std::fmt::Debug::fmt(&self.repr, f)
80 }
81}
82
83impl<K> Set<K> {
84 fn from_repr(repr: SetRepr) -> Self {
85 Self { repr, axes: None, _k: PhantomData }
86 }
87
88 fn from_repr_with_axes(repr: SetRepr, axes: Box<[Axis]>) -> Self {
89 Self { repr, axes: Some(axes), _k: PhantomData }
90 }
91
92 pub(crate) fn axes(&self) -> Option<&[Axis]> {
96 self.axes.as_deref()
97 }
98
99 pub fn tuples<I, T>(iter: I) -> Self
101 where
102 I: IntoIterator<Item = T>,
103 T: Into<IndexTuple>,
104 {
105 Self::from_repr(SetRepr::Tuples(iter.into_iter().map(Into::into).collect()))
106 }
107
108 #[must_use]
111 pub fn filter<F>(&self, mut f: F) -> Self
112 where
113 F: FnMut(&IndexKey) -> bool,
114 {
115 let repr = match &self.repr {
116 SetRepr::Range(v) => {
117 SetRepr::Range(v.iter().copied().filter(|i| f(&IndexKey::Int(*i))).collect())
118 }
119 SetRepr::Strings(v) => SetRepr::Strings(
120 v.iter()
121 .filter_map(|s| {
122 let key = IndexKey::Str(s.clone());
123 f(&key).then(|| s.clone())
124 })
125 .collect(),
126 ),
127 SetRepr::Tuples(v) => SetRepr::Tuples(
128 v.iter()
129 .filter_map(|t| {
130 let key = IndexKey::Tuple(t.clone());
131 f(&key).then(|| match key {
132 IndexKey::Tuple(owned) => owned,
133 _ => unreachable!(),
134 })
135 })
136 .collect(),
137 ),
138 };
139 Self::from_repr(repr)
140 }
141
142 #[must_use]
157 pub fn filter_typed<F>(&self, mut pred: F) -> Self
158 where
159 K: FromIndexKey,
160 F: FnMut(K) -> bool,
161 {
162 self.filter(|k| pred(K::from_index_key(k)))
163 }
164
165 pub fn len(&self) -> usize {
166 self.repr.len()
167 }
168
169 pub fn is_empty(&self) -> bool {
170 self.len() == 0
171 }
172
173 pub fn is_range(&self) -> bool {
175 matches!(self.repr, SetRepr::Range(_))
176 }
177
178 pub fn is_strings(&self) -> bool {
180 matches!(self.repr, SetRepr::Strings(_))
181 }
182
183 pub fn is_tuples(&self) -> bool {
185 matches!(self.repr, SetRepr::Tuples(_))
186 }
187
188 #[must_use]
194 pub fn product<B>(a: &Set<K>, b: &Set<B>) -> Set<<K as KeyCat<B>>::Out>
195 where
196 K: KeyCat<B>,
197 {
198 let a_len = a.len();
199 let b_len = b.len();
200 let total = a_len.checked_mul(b_len).expect("Set::product size overflow");
201
202 let axes = match (a.axes(), b.axes()) {
203 (Some(aa), Some(bb)) => {
204 let mut v = Vec::with_capacity(aa.len() + bb.len());
205 v.extend_from_slice(aa);
206 v.extend_from_slice(bb);
207 Some(v.into_boxed_slice())
208 }
209 _ => None,
210 };
211
212 const PAR_THRESHOLD: usize = 4096;
215 let out: Vec<IndexTuple> = if total < PAR_THRESHOLD {
216 let mut out = Vec::with_capacity(total);
217 for ka in a {
218 for kb in b {
219 let mut parts: Vec<IndexKey> = Vec::new();
220 push_flat(&mut parts, ka.clone());
221 push_flat(&mut parts, kb);
222 out.push(parts.into_boxed_slice());
223 }
224 }
225 out
226 } else {
227 let a_keys: Vec<IndexKey> = a.iter().collect();
228 let b_keys: Vec<IndexKey> = b.iter().collect();
229 (0..total)
230 .into_par_iter()
231 .map(|i| {
232 let mut parts: Vec<IndexKey> = Vec::new();
233 push_flat(&mut parts, a_keys[i / b_len].clone());
234 push_flat(&mut parts, b_keys[i % b_len].clone());
235 parts.into_boxed_slice()
236 })
237 .collect()
238 };
239
240 match axes {
241 Some(axes) => Set::from_repr_with_axes(SetRepr::Tuples(out), axes),
242 None => Set::from_repr(SetRepr::Tuples(out)),
243 }
244 }
245}
246
247impl Set<usize> {
248 #[must_use]
258 pub fn range<T: PrimInt>(r: std::ops::Range<T>) -> Self {
259 let start = r.start.to_i64().expect("range start out of i64 range");
260 let end = r.end.to_i64().expect("range end out of i64 range");
261 Self::dense_i64(start, end)
262 }
263
264 pub(crate) fn dense_i64(start: i64, end: i64) -> Self {
268 let vals: Vec<i64> = (start..end).collect();
269 let len = vals.len();
270 Self::from_repr_with_axes(SetRepr::Range(vals), Box::from([Axis { start, len }]))
271 }
272
273 pub fn from_ints<T, I>(iter: I) -> Self
279 where
280 T: PrimInt,
281 I: IntoIterator<Item = T>,
282 {
283 Self::from_repr(SetRepr::Range(
284 iter.into_iter().map(|v| v.to_i64().expect("element out of i64 range")).collect(),
285 ))
286 }
287}
288
289impl Set<String> {
290 pub fn strings<I, S>(iter: I) -> Self
291 where
292 I: IntoIterator<Item = S>,
293 S: Into<SmolStr>,
294 {
295 Self::from_repr(SetRepr::Strings(iter.into_iter().map(Into::into).collect()))
296 }
297}
298
299fn push_flat(dst: &mut Vec<IndexKey>, k: IndexKey) {
300 match k {
301 IndexKey::Tuple(inner) => dst.extend(inner.into_vec()),
302 other => dst.push(other),
303 }
304}
305
306fn make_tuple<I: IntoIterator<Item = IndexKey>>(items: I) -> IndexTuple {
307 let mut v: Vec<IndexKey> = Vec::new();
308 for k in items {
309 push_flat(&mut v, k);
310 }
311 v.into_boxed_slice()
312}
313
314impl<A, B> Mul<&Set<B>> for &Set<A>
315where
316 A: KeyCat<B>,
317{
318 type Output = Set<<A as KeyCat<B>>::Out>;
319 fn mul(self, rhs: &Set<B>) -> Self::Output {
320 Set::product(self, rhs)
321 }
322}
323
/// Type-level rule for the key type of a Cartesian product.
///
/// `Self` is the key type of the left operand, `Rhs` that of the right
/// operand, and `Out` the key type of the resulting set. `Set::product` and
/// the `*` operator on set references are bounded by this trait.
#[diagnostic::on_unimplemented(
    message = "cannot form a Cartesian product index key from `{Self}` and `{Rhs}`",
    label = "no product key for `{Self}` * `{Rhs}`",
    note = "`&a * &b` composes scalar keys (`usize`/`i64`/`i32`/`String`) into flat tuples up to arity 4. A 5th axis or a non-scalar operand is unsupported"
)]
pub trait KeyCat<Rhs> {
    /// Key type of the product set.
    type Out;
}
335
/// Marker for key types that occupy a single tuple position in the `KeyCat`
/// rules below.
pub trait ScalarKey {}
impl ScalarKey for usize {}
impl ScalarKey for i32 {}
impl ScalarKey for i64 {}
impl ScalarKey for String {}
// The untyped key counts as one position as well.
impl ScalarKey for IndexKey {}
344
/// scalar * scalar -> pair.
impl<A: ScalarKey, B: ScalarKey> KeyCat<B> for A {
    type Out = (A, B);
}

/// pair * scalar -> triple. Only the appended element is bounded by `ScalarKey`.
impl<A, B, C: ScalarKey> KeyCat<C> for (A, B) {
    type Out = (A, B, C);
}

/// triple * scalar -> 4-tuple.
impl<A, B, C, D: ScalarKey> KeyCat<D> for (A, B, C) {
    type Out = (A, B, C, D);
}

/// scalar * pair -> triple.
impl<A: ScalarKey, B, C> KeyCat<(B, C)> for A {
    type Out = (A, B, C);
}

/// scalar * triple -> 4-tuple.
impl<A: ScalarKey, B, C, D> KeyCat<(B, C, D)> for A {
    type Out = (A, B, C, D);
}

/// pair * pair -> 4-tuple.
impl<A, B, C, D> KeyCat<(C, D)> for (A, B) {
    type Out = (A, B, C, D);
}
372
/// A dynamically typed index key: an integer, a string, or a tuple of keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum IndexKey {
    /// Integer key.
    Int(i64),
    /// String key.
    Str(SmolStr),
    /// Tuple key. `IndexKey::tuple` and the `From<(..)>` conversions build it
    /// through `make_tuple`, which splices tuple components in flat.
    Tuple(IndexTuple),
}
380
381impl IndexKey {
382 pub fn tuple<I, T>(iter: I) -> Self
385 where
386 I: IntoIterator<Item = T>,
387 T: Into<IndexKey>,
388 {
389 Self::Tuple(make_tuple(iter.into_iter().map(Into::into)))
390 }
391
392 pub fn as_i64(&self) -> Option<i64> {
393 if let Self::Int(v) = self { Some(*v) } else { None }
394 }
395
396 pub fn as_str(&self) -> Option<&str> {
397 if let Self::Str(s) = self { Some(s.as_str()) } else { None }
398 }
399
400 pub fn as_tuple(&self) -> Option<&[IndexKey]> {
401 if let Self::Tuple(t) = self { Some(&t[..]) } else { None }
402 }
403}
404
405impl From<i64> for IndexKey {
406 fn from(v: i64) -> Self {
407 Self::Int(v)
408 }
409}
410
411impl From<i32> for IndexKey {
412 fn from(v: i32) -> Self {
413 Self::Int(i64::from(v))
414 }
415}
416
417impl From<usize> for IndexKey {
418 fn from(v: usize) -> Self {
419 Self::Int(i64::try_from(v).expect("usize -> i64 overflow"))
420 }
421}
422
423impl From<&str> for IndexKey {
424 fn from(s: &str) -> Self {
425 Self::Str(SmolStr::new(s))
426 }
427}
428
429impl From<String> for IndexKey {
430 fn from(s: String) -> Self {
431 Self::Str(SmolStr::from(s))
432 }
433}
434
435impl From<&String> for IndexKey {
436 fn from(s: &String) -> Self {
437 Self::Str(SmolStr::new(s.as_str()))
438 }
439}
440
441impl From<&usize> for IndexKey {
443 fn from(v: &usize) -> Self {
444 Self::from(*v)
445 }
446}
447
448impl From<&i64> for IndexKey {
449 fn from(v: &i64) -> Self {
450 Self::Int(*v)
451 }
452}
453
454impl From<&i32> for IndexKey {
455 fn from(v: &i32) -> Self {
456 Self::Int(i64::from(*v))
457 }
458}
459
460impl From<&&str> for IndexKey {
461 fn from(s: &&str) -> Self {
462 Self::Str(SmolStr::new(*s))
463 }
464}
465
466impl From<&&String> for IndexKey {
467 fn from(s: &&String) -> Self {
468 Self::Str(SmolStr::new(s.as_str()))
469 }
470}
471
472impl<A, B> From<(A, B)> for IndexKey
473where
474 A: Into<IndexKey>,
475 B: Into<IndexKey>,
476{
477 fn from(t: (A, B)) -> Self {
478 Self::Tuple(make_tuple([t.0.into(), t.1.into()]))
479 }
480}
481
482impl<A, B, C> From<(A, B, C)> for IndexKey
483where
484 A: Into<IndexKey>,
485 B: Into<IndexKey>,
486 C: Into<IndexKey>,
487{
488 fn from(t: (A, B, C)) -> Self {
489 Self::Tuple(make_tuple([t.0.into(), t.1.into(), t.2.into()]))
490 }
491}
492
493impl<A, B, C, D> From<(A, B, C, D)> for IndexKey
494where
495 A: Into<IndexKey>,
496 B: Into<IndexKey>,
497 C: Into<IndexKey>,
498 D: Into<IndexKey>,
499{
500 fn from(t: (A, B, C, D)) -> Self {
501 Self::Tuple(make_tuple([t.0.into(), t.1.into(), t.2.into(), t.3.into()]))
502 }
503}
504
/// Decodes a typed key from a borrowed `IndexKey`.
///
/// The conversion is infallible at the type level (it returns `Self`, not a
/// `Result`); the implementations in this file panic when the key does not
/// have the expected shape or range.
#[diagnostic::on_unimplemented(
    message = "`{Self}` is not a valid index key type",
    label = "cannot be decoded from an index key",
    note = "index keys decode to `usize`, `i64`, `i32`, `String`, `IndexKey`, or a tuple of those up to arity 4",
    note = "annotate the binding to one of these (e.g. `for k: usize in set`) or match the `Set`'s key type"
)]
pub trait FromIndexKey: Sized {
    /// Decodes `k` into `Self`.
    fn from_index_key(k: &IndexKey) -> Self;
}
527
528impl FromIndexKey for IndexKey {
529 fn from_index_key(k: &IndexKey) -> Self {
530 k.clone()
531 }
532}
533
534impl FromIndexKey for i64 {
535 fn from_index_key(k: &IndexKey) -> Self {
536 k.as_i64().unwrap_or_else(|| panic!("expected Int key, got {k:?}"))
537 }
538}
539
540impl FromIndexKey for i32 {
541 fn from_index_key(k: &IndexKey) -> Self {
542 let v = i64::from_index_key(k);
543 i32::try_from(v).unwrap_or_else(|_| panic!("key {v} out of i32 range"))
544 }
545}
546
547impl FromIndexKey for usize {
548 fn from_index_key(k: &IndexKey) -> Self {
549 let v = i64::from_index_key(k);
550 usize::try_from(v).unwrap_or_else(|_| panic!("key {v} out of usize range"))
551 }
552}
553
554impl FromIndexKey for String {
555 fn from_index_key(k: &IndexKey) -> Self {
556 k.as_str().unwrap_or_else(|| panic!("expected Str key, got {k:?}")).to_owned()
557 }
558}
559
560fn tuple_parts<'a>(k: &'a IndexKey, expected: usize) -> &'a [IndexKey] {
561 let p = k.as_tuple().unwrap_or_else(|| panic!("expected Tuple key, got {k:?}"));
562 assert_eq!(p.len(), expected, "expected tuple of arity {expected}, got arity {}", p.len());
563 p
564}
565
566impl<A, B> FromIndexKey for (A, B)
567where
568 A: FromIndexKey,
569 B: FromIndexKey,
570{
571 fn from_index_key(k: &IndexKey) -> Self {
572 let p = tuple_parts(k, 2);
573 (A::from_index_key(&p[0]), B::from_index_key(&p[1]))
574 }
575}
576
577impl<A, B, C> FromIndexKey for (A, B, C)
578where
579 A: FromIndexKey,
580 B: FromIndexKey,
581 C: FromIndexKey,
582{
583 fn from_index_key(k: &IndexKey) -> Self {
584 let p = tuple_parts(k, 3);
585 (A::from_index_key(&p[0]), B::from_index_key(&p[1]), C::from_index_key(&p[2]))
586 }
587}
588
589impl<A, B, C, D> FromIndexKey for (A, B, C, D)
590where
591 A: FromIndexKey,
592 B: FromIndexKey,
593 C: FromIndexKey,
594 D: FromIndexKey,
595{
596 fn from_index_key(k: &IndexKey) -> Self {
597 let p = tuple_parts(k, 4);
598 (
599 A::from_index_key(&p[0]),
600 B::from_index_key(&p[1]),
601 C::from_index_key(&p[2]),
602 D::from_index_key(&p[3]),
603 )
604 }
605}
606
607impl<'a, K> IntoIterator for &'a Set<K> {
608 type Item = IndexKey;
609 type IntoIter = SetIter<'a>;
610 fn into_iter(self) -> Self::IntoIter {
611 self.iter()
612 }
613}
614
615impl<K> Set<K> {
616 pub fn iter(&self) -> SetIter<'_> {
617 SetIter { repr: &self.repr, pos: 0 }
618 }
619
620 pub fn par_iter(&self) -> impl ParallelIterator<Item = IndexKey> + '_ {
621 match &self.repr {
622 SetRepr::Range(v) => v.par_iter().map(|i| IndexKey::Int(*i)).collect::<Vec<_>>(),
623 SetRepr::Strings(v) => {
624 v.par_iter().map(|s| IndexKey::Str(s.clone())).collect::<Vec<_>>()
625 }
626 SetRepr::Tuples(v) => {
627 v.par_iter().map(|t| IndexKey::Tuple(t.clone())).collect::<Vec<_>>()
628 }
629 }
630 .into_par_iter()
631 }
632}
633
/// Sequential iterator over the keys of a `Set`, yielding owned `IndexKey`s.
#[derive(Debug)]
pub struct SetIter<'a> {
    /// Storage being walked.
    repr: &'a SetRepr,
    /// Index of the next key to yield.
    pos: usize,
}
639
640impl<'a> Iterator for SetIter<'a> {
641 type Item = IndexKey;
642 fn next(&mut self) -> Option<Self::Item> {
643 let out = match self.repr {
644 SetRepr::Range(v) => v.get(self.pos).copied().map(IndexKey::Int),
645 SetRepr::Strings(v) => v.get(self.pos).cloned().map(IndexKey::Str),
646 SetRepr::Tuples(v) => v.get(self.pos).cloned().map(IndexKey::Tuple),
647 };
648 if out.is_some() {
649 self.pos += 1;
650 }
651 out
652 }
653}