afarray/
coords.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::iter::IntoIterator;
4use std::mem;
5use std::ops::{Deref, DerefMut};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use arrayfire as af;
10use futures::ready;
11use futures::stream::{Fuse, FusedStream, Stream, StreamExt, TryStream, TryStreamExt};
12use pin_project::pin_project;
13
14use super::{coord_bounds, ArrayExt};
15
/// An n-dimensional coordinate: one `u64` index per axis.
pub type Coord = Vec<u64>;

/// One-dimensional array indices corresponding to n-dimensional coordinates.
///
/// Produced by [`Coords::to_offsets`] and consumed by [`Coords::from_offsets`].
pub type Offsets = ArrayExt<u64>;
21
/// A hardware-accelerated set of n-dimensional coordinates, all with the same dimension.
///
/// TODO: separate out a `CoordBasis` struct
#[derive(Clone)]
pub struct Coords {
    /// One column per coordinate and one row per axis, i.e. dims `[ndim, len, 1, 1]`.
    array: af::Array<u64>,
    /// The number of axes of each coordinate (the number of rows of `array`).
    ndim: usize,
}
30
31impl Coords {
32    /// Constructs `Coords` with the given `size` full of zeros (origin points) for the given shape.
33    ///
34    /// Panics: if shape is empty
35    pub fn empty(shape: &[u64], size: usize) -> Self {
36        assert!(!shape.is_empty());
37        assert!(size > 0);
38
39        let ndim = shape.len();
40        let dims = af::Dim4::new(&[ndim as u64, size as u64, 1, 1]);
41        let array = af::constant(0u64, dims);
42        assert_eq!(array.dims(), dims);
43        Self { array, ndim }
44    }
45
46    /// Constructs a new `Coords` from an iterator of [`Coord`]s.
47    ///
48    /// Panics: if any [`Coord`] is not of length `ndim`, or if `ndim` is zero.
49    pub fn from_iter<I: IntoIterator<Item = Coord>>(iter: I, ndim: usize) -> Self {
50        assert!(ndim > 0);
51
52        let buffer: Vec<u64> = iter
53            .into_iter()
54            .inspect(|coord| assert_eq!(coord.len(), ndim))
55            .flatten()
56            .collect();
57
58        let num_coords = buffer.len() / ndim;
59        let dims = af::Dim4::new(&[ndim as u64, num_coords as u64, 1, 1]);
60        let array = af::Array::new(&buffer, dims);
61        Self { array, ndim }
62    }
63
64    /// Constructs a new `Coords` from an [`ArrayExt`] of offsets with respect to the given shape.
65    ///
66    /// Panics: if `shape` is empty
67    pub fn from_offsets(offsets: Offsets, shape: &[u64]) -> Self {
68        assert!(!shape.is_empty());
69
70        let ndim = shape.len() as u64;
71        let coord_bounds = coord_bounds(shape);
72
73        let dims = af::Dim4::new(&[1, ndim, 1, 1]);
74        let af_coord_bounds: af::Array<u64> = af::Array::new(&coord_bounds, dims);
75        let af_shape: af::Array<u64> = af::Array::new(&shape, dims);
76
77        let offsets = af::div(offsets.deref(), &af_coord_bounds, true);
78        let coords = af::modulo(&offsets, &af_shape, true);
79        let array = af::transpose(&coords, false);
80
81        Self {
82            array,
83            ndim: shape.len(),
84        }
85    }
86
87    /// Constructs a new `Coords` from a [`Stream`] of [`Coord`]s.
88    ///
89    /// Panics: if any [`Coord`] has a length other than `ndim`, or if `ndim` is zero
90    pub async fn from_stream<S: Stream<Item = Coord> + Unpin>(
91        mut source: S,
92        ndim: usize,
93        size_hint: Option<usize>,
94    ) -> Self {
95        assert!(ndim > 0);
96
97        let mut num_coords = 0;
98        let mut buffer = if let Some(size) = size_hint {
99            Vec::with_capacity(size)
100        } else {
101            Vec::new()
102        };
103
104        while let Some(coord) = source.next().await {
105            assert_eq!(coord.len(), ndim);
106            buffer.extend(coord);
107            num_coords += 1;
108        }
109
110        let array = af::Array::new(&buffer, af::Dim4::new(&[ndim as u64, num_coords, 1, 1]));
111
112        Self { array, ndim }
113    }
114
115    /// Constructs a new `Coords` from a [`TryStream`] of `Coord`s.
116    ///
117    /// Panics: if any [`Coord`] has a length other than `ndim`
118    pub async fn try_from_stream<E, S: TryStream<Ok = Coord, Error = E> + Unpin>(
119        mut source: S,
120        ndim: usize,
121        size_hint: Option<usize>,
122    ) -> Result<Self, E> {
123        let mut num_coords = 0;
124        let mut buffer = if let Some(size) = size_hint {
125            Vec::with_capacity(size)
126        } else {
127            Vec::new()
128        };
129
130        while let Some(coord) = source.try_next().await? {
131            assert_eq!(coord.len(), ndim);
132            buffer.extend(coord);
133            num_coords += 1;
134        }
135
136        let array = af::Array::new(&buffer, af::Dim4::new(&[ndim as u64, num_coords, 1, 1]));
137
138        Ok(Self { array, ndim })
139    }
140
141    /// Return `true` if the number of coordinates in these `Coords` is zero.
142    pub fn is_empty(&self) -> bool {
143        self.array.elements() == 0
144    }
145
146    /// Return `true` if these `Coords` are in sorted order with respect to the given `shape`.
147    pub fn is_sorted(&self, shape: &[u64]) -> bool {
148        self.to_offsets(shape).is_sorted()
149    }
150
151    /// Return the number of coordinates stored in these `Coords`.
152    pub fn len(&self) -> usize {
153        self.dims()[1] as usize
154    }
155
156    /// Return the number of dimensions of these `Coords`.
157    pub fn ndim(&self) -> usize {
158        self.ndim
159    }
160
161    fn last(&self) -> Coord {
162        let i = (self.len() - 1) as i32;
163        let dim0 = af::seq!(0, (self.ndim - 1) as i32, 1);
164        let dim1 = af::seq!(i, i, 1);
165        let slice = af::index(self, &[dim0, dim1]);
166        let mut first = vec![0; self.ndim];
167        slice.host(&mut first);
168        first
169    }
170
171    fn append(&self, other: &Coords) -> Self {
172        assert_eq!(self.ndim, other.ndim);
173
174        let array = af::join(1, self, other);
175        Self {
176            array,
177            ndim: self.ndim,
178        }
179    }
180
181    fn split(&self, at: usize) -> (Self, Self) {
182        assert!(at > 0);
183        assert!(at < self.len());
184
185        let left = af::seq!(0, (at - 1) as i32, 1);
186        let right = af::seq!(at as i32, (self.len() - 1) as i32, 1);
187
188        let left = af::index(self, &[af::seq!(), left]);
189        let right = af::index(self, &[af::seq!(), right]);
190
191        (
192            Self {
193                array: left,
194                ndim: self.ndim,
195            },
196            Self {
197                array: right,
198                ndim: self.ndim,
199            },
200        )
201    }
202
203    fn split_lte(&self, lt: &[u64], shape: &[u64]) -> (Option<Self>, Option<Self>) {
204        assert_eq!(lt.len(), self.ndim);
205        assert_eq!(shape.len(), self.ndim);
206
207        let coord_bounds = coord_bounds(shape);
208        let pivot = coord_to_offset(lt, &coord_bounds);
209        let pivot = af::Array::new(&[pivot], af::dim4!(1));
210        let offsets = self.to_offsets(shape);
211        let left = af::le(offsets.deref(), &pivot, true);
212        let pivot = af::sum_all(&left).0;
213
214        if pivot == 0 {
215            return (None, Some(self.clone()));
216        } else if pivot == self.len() as u32 {
217            return (Some(self.clone()), None);
218        }
219
220        let (l, r) = self.split(pivot as usize);
221
222        debug_assert_eq!(l.array.dims()[0], self.ndim as u64);
223        debug_assert_eq!(r.array.dims()[0], self.ndim as u64);
224
225        (Some(l), Some(r))
226    }
227
228    fn sorted(&self) -> Self {
229        let array = af::sort(self, 2, true);
230        Self {
231            array,
232            ndim: self.ndim,
233        }
234    }
235
236    fn unique(&self, shape: &[u64]) -> Self {
237        let offsets = self.to_offsets(shape);
238        let offsets = af::set_unique(offsets.deref(), true);
239        Self::from_offsets(offsets.into(), shape)
240    }
241
242    /// Return a copy of these `Coords` without the specified axis.
243    ///
244    /// Panics: if there is no dimension at `axis`
245    pub fn contract_dim(&self, axis: usize) -> Self {
246        assert!(axis < self.ndim);
247
248        let mut index: Vec<usize> = (0..self.ndim).collect();
249        index.remove(axis);
250
251        self.get(&index)
252    }
253
254    /// Return the `Coords` of source elements to reduce along the given axis.
255    ///
256    /// Panics: if `reduce_axis <= source_shape.len()`
257    pub fn expand(&self, source_shape: &[u64], reduce_axis: usize) -> Self {
258        let ndim = self.ndim + 1;
259
260        assert_eq!(source_shape.len(), ndim);
261        assert!(reduce_axis <= ndim);
262
263        let reduce_dim = source_shape[reduce_axis];
264
265        let dims = af::dim4!(1, reduce_dim);
266        let reduced = af::range(dims, 1);
267
268        let reduce_index = vec![reduce_axis as u64];
269
270        let index: Vec<u64> = (0..self.ndim)
271            .map(|x| if x < reduce_axis { x } else { x + 1 })
272            .map(|x| x as u64)
273            .collect();
274
275        let tile_dims = af::dim4!(1, reduce_dim);
276        let source_coord_dims = af::dim4!(ndim as u64, reduce_dim);
277
278        let mut expanded = Vec::with_capacity(self.len());
279        for i in 0..self.dims()[1] {
280            let i = i as i32;
281            let seqs = &[af::seq!(), af::seq!(i, i, 1)];
282            let coord = af::index(&self.array, seqs);
283            let coord = af::tile(&coord, tile_dims);
284
285            let mut expanded_coord = af::constant(0, source_coord_dims);
286            index_set(&mut expanded_coord, &index, &coord);
287            index_set(&mut expanded_coord, &reduce_index, &reduced);
288            expanded.push(expanded_coord);
289        }
290
291        Self {
292            array: af::join_many(1, expanded.iter().collect()),
293            ndim,
294        }
295    }
296
297    /// Return a copy of these `Coords` with a new dimension at the given axis.
298    ///
299    /// Panics: if `axis` is greater than `self.ndim()`
300    pub fn expand_dim(&self, axis: usize) -> Self {
301        assert!(axis <= self.ndim);
302
303        let ndim = self.ndim + 1;
304        let dims = af::Dim4::new(&[ndim as u64, self.dims()[1], 1, 1]);
305        let mut expanded = af::constant(0, dims);
306
307        let index: Vec<u64> = (0..self.ndim())
308            .map(|x| if x < axis { x } else { x + 1 })
309            .map(|x| x as u64)
310            .collect();
311
312        index_set(&mut expanded, &index, self);
313
314        Self {
315            array: expanded,
316            ndim,
317        }
318    }
319
320    /// Return these `Coords` as flipped around `axis` with respect to the given `shape`.
321    ///
322    /// E.g. flipping axis 1 of coordinate `[0, 1, 2]` in shape `[5, 5, 5]` produces `[0, 4, 2]`.
323    ///
324    /// Panics: if `self.ndim() != shape.len()`
325    pub fn flip(self, shape: &[u64], axis: usize) -> Self {
326        assert_eq!(self.ndim, shape.len());
327
328        let mut mask = vec![0i64; self.ndim()];
329        mask[axis] = (shape[axis] - 1) as i64;
330        let mask = af::Array::new(&mask, af::Dim4::new(&[self.ndim() as u64, 1, 1, 1]));
331
332        let coords: af::Array<i64> = self.array.cast();
333        let flipped = af::sub(&mask, &coords, true);
334
335        Self {
336            array: af::abs(&flipped).cast(),
337            ndim: self.ndim,
338        }
339    }
340
341    /// Transform the coordinate basis of these `Coords` from a source tensor to a slice.
342    pub fn slice(
343        &self,
344        shape: &[u64],
345        elided: &HashMap<usize, u64>,
346        offset: &HashMap<usize, u64>,
347    ) -> Self {
348        let ndim = shape.len();
349        let mut offsets = Vec::with_capacity(ndim);
350        let mut index = Vec::with_capacity(ndim);
351        for x in 0..self.ndim {
352            if elided.contains_key(&x) {
353                continue;
354            }
355
356            let offset = offset.get(&x).unwrap_or(&0);
357            offsets.push(*offset);
358            index.push(x);
359        }
360
361        let offsets = af::Array::new(&offsets, af::dim4!(offsets.len() as u64));
362        let array = af::sub(self.get(&index).deref(), &offsets, true);
363        Self { array, ndim }
364    }
365
366    /// Transpose these `Coords` according to the given `permutation`.
367    ///
368    /// If no permutation is given, the coordinate axes will be inverted.
369    pub fn transpose<P: AsRef<[usize]>>(&self, permutation: Option<P>) -> Coords {
370        if let Some(permutation) = permutation {
371            self.get(permutation.as_ref())
372        } else {
373            let array = af::transpose(&self.array, false);
374            let ndim = self.ndim;
375            Self { array, ndim }
376        }
377    }
378
379    /// Invert the given broadcast of these `Coords`.
380    ///
381    /// Panics: if `source_shape` and `broadcast` are not the same length.
382    pub fn unbroadcast(&self, source_shape: &[u64], broadcast: &[bool]) -> Coords {
383        assert_eq!(self.ndim(), broadcast.len());
384
385        let offset = self.ndim() - source_shape.len();
386        let mut coords = Self::empty(source_shape, self.len());
387        if source_shape.is_empty() || broadcast.iter().all(|b| *b) {
388            return coords;
389        }
390
391        let axes: Vec<usize> = broadcast
392            .iter()
393            .enumerate()
394            .filter_map(|(x, b)| if *b { None } else { Some(x) })
395            .collect();
396
397        let unbroadcasted = self.get(&axes);
398
399        let axes: Vec<usize> = broadcast
400            .iter()
401            .enumerate()
402            .filter_map(|(x, b)| if *b { None } else { Some(x - offset) })
403            .collect();
404
405        coords.set(&axes, &unbroadcasted);
406
407        coords
408    }
409
410    /// Transform the coordinate basis of these `Coords` from a slice to a source tensor.
411    ///
412    /// Panics: if `source_shape.len() - self.ndim()` does not match `elided.len()`
413    pub fn unslice(
414        &self,
415        source_shape: &[u64],
416        elided: &HashMap<usize, u64>,
417        offset: &HashMap<usize, u64>,
418    ) -> Self {
419        let ndim = source_shape.len();
420        let mut axes = Vec::with_capacity(self.ndim);
421        let mut unsliced = vec![0; source_shape.len()];
422        let mut offsets = vec![0; source_shape.len()];
423        for x in 0..ndim {
424            if let Some(elide) = elided.get(&x) {
425                unsliced[x] = *elide;
426            } else {
427                axes.push(x as u64);
428                offsets[x] = *offset.get(&x).unwrap_or(&0);
429            }
430        }
431        assert_eq!(axes.len(), self.ndim);
432
433        let unsliced = af::Array::new(&unsliced, af::dim4!(ndim as u64));
434        let tile_dims = af::Dim4::new(&[1, self.len() as u64, 1, 1]);
435        let mut unsliced = af::tile(&unsliced, tile_dims);
436        index_set(&mut unsliced, &axes, self);
437
438        let offsets = af::Array::new(&offsets, af::dim4!(ndim as u64));
439        let offsets = af::tile(&offsets, tile_dims);
440
441        Self {
442            array: unsliced + offsets,
443            ndim,
444        }
445    }
446
447    /// Construct a new `Coords` from the selected indices.
448    ///
449    /// Panics: if any index is out of bounds
450    pub fn get(&self, axes: &[usize]) -> Self {
451        let axes: Vec<u64> = axes
452            .iter()
453            .map(|x| {
454                assert!(x < &self.ndim);
455                *x as u64
456            })
457            .collect();
458
459        let array = index_get(self, &axes);
460        Self {
461            array,
462            ndim: axes.len(),
463        }
464    }
465
466    /// Update these `Coords` by writing the given `value` at the given `index`.
467    ///
468    /// Panics: if any index is out of bounds, or if `value.len()` does not match `self.len()`
469    pub fn set(&mut self, axes: &[usize], value: &Self) {
470        let axes: Vec<u64> = axes
471            .iter()
472            .map(|x| {
473                assert!(x < &self.ndim);
474                *x as u64
475            })
476            .collect();
477
478        index_set(self, &axes, value)
479    }
480
481    /// Return these `Coords` as [`Offsets`] with respect to the given shape.
482    ///
483    /// Panics: if `shape.len()` does not equal `self.ndim()`
484    pub fn to_offsets(&self, shape: &[u64]) -> ArrayExt<u64> {
485        let ndim = shape.len();
486        assert_eq!(self.ndim, ndim);
487
488        let coord_bounds = coord_bounds(shape);
489        let af_coord_bounds: af::Array<u64> = af::Array::new(&coord_bounds, af::dim4!(ndim as u64));
490
491        let offsets = af::mul(&self.array, &af_coord_bounds, true);
492        let offsets = af::sum(&offsets, 0).into();
493        af::moddims(&offsets, af::dim4!(offsets.elements() as u64)).into()
494    }
495
496    /// Return a list of [`Coord`]s from these `Coords`.
497    ///
498    /// Panics: if the given number of dimensions does not fit the set of coordinates
499    pub fn to_vec(&self) -> Vec<Coord> {
500        assert_eq!(self.array.elements() % self.ndim, 0);
501
502        let mut to_vec = vec![0u64; self.array.elements()];
503        self.array.host(&mut to_vec);
504
505        to_vec
506            .chunks(self.ndim)
507            .map(|coord| coord.to_vec())
508            .collect()
509    }
510
511    /// Convert these `Coords` into a list of [`Coord`]s.
512    ///
513    /// Panics: if the given number of dimensions does not fit the set of coordinates.
514    pub fn into_vec(self) -> Vec<Coord> {
515        self.to_vec()
516    }
517}
518
519impl Deref for Coords {
520    type Target = af::Array<u64>;
521
522    fn deref(&self) -> &Self::Target {
523        &self.array
524    }
525}
526
527impl DerefMut for Coords {
528    fn deref_mut(&mut self) -> &mut Self::Target {
529        &mut self.array
530    }
531}
532
533impl PartialEq for Coords {
534    fn eq(&self, other: &Self) -> bool {
535        if self.ndim == other.ndim {
536            let batch = self.array.dims() != other.array.dims();
537            af::all_true_all(&af::eq(&self.array, &other.array, batch)).0
538        } else {
539            false
540        }
541    }
542}
543
544impl fmt::Debug for Coords {
545    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
546        write!(f, "a block of {} coordinates", self.len())
547    }
548}
549
550/// A [`Stream`] of [`Coords`], as constructed from an input stream of [`Coord`]s.
551pub struct CoordBlocks<S> {
552    source: Fuse<S>,
553    ndim: usize,
554    block_size: usize,
555    buffer: Vec<u64>,
556}
557
558impl<E, S: Stream<Item = Result<Coord, E>>> CoordBlocks<S> {
559    /// Construct a new `CoordBlocks`.
560    ///
561    /// Panics: if `ndim == 0`
562    pub fn new(source: S, ndim: usize, block_size: usize) -> Self {
563        assert!(ndim > 0);
564
565        Self {
566            source: source.fuse(),
567            ndim,
568            block_size,
569            buffer: Vec::with_capacity(ndim * block_size),
570        }
571    }
572
573    fn consume_buffer(&mut self) -> Coords {
574        assert_eq!(self.buffer.len() % self.ndim, 0);
575
576        let ndim = self.ndim as u64;
577        let num_coords = (self.buffer.len() / self.ndim) as u64;
578        let dims = af::Dim4::new(&[ndim, num_coords, 1, 1]);
579        let coords = Coords {
580            array: af::Array::new(&self.buffer, dims),
581            ndim: self.ndim,
582        };
583
584        self.buffer.clear();
585        coords
586    }
587}
588
589impl<E, S: Stream<Item = Result<Coord, E>> + Unpin> Stream for CoordBlocks<S> {
590    type Item = Result<Coords, E>;
591
592    fn poll_next(mut self: Pin<&mut Self>, cxt: &mut Context<'_>) -> Poll<Option<Self::Item>> {
593        Poll::Ready(loop {
594            match ready!(Pin::new(&mut self.source).poll_next(cxt)) {
595                Some(Ok(coord)) => {
596                    assert_eq!(coord.len(), self.ndim);
597                    self.buffer.extend(coord);
598
599                    if self.buffer.len() == (self.block_size * self.ndim) {
600                        break Some(Ok(self.consume_buffer()));
601                    }
602                }
603                Some(Err(cause)) => break Some(Err(cause)),
604                None if self.buffer.is_empty() => break None,
605                None => break Some(Ok(self.consume_buffer())),
606            }
607        })
608    }
609}
610
611impl<E, S: Stream<Item = Result<Coord, E>> + Unpin> FusedStream for CoordBlocks<S> {
612    fn is_terminated(&self) -> bool {
613        self.source.is_terminated() && self.buffer.is_empty()
614    }
615}
616
617/// Stream for merging two sorted [`CoordBlocks`] streams.
618///
619/// The behavior of `CoordMerge` is undefined if the input streams are not sorted.
620#[pin_project]
621pub struct CoordMerge<L, R> {
622    #[pin]
623    left: Fuse<L>,
624
625    #[pin]
626    right: Fuse<R>,
627
628    pending_left: Option<Coords>,
629    pending_right: Option<Coords>,
630    buffer: Option<Coords>,
631    block_size: usize,
632    shape: Vec<u64>,
633}
634
635impl<L: Stream, R: Stream> CoordMerge<L, R> {
636    /// Construct a new `CoordMerge` stream.
637    ///
638    /// Panics: if the dimensions of `left`, `right`, and `shape` don't match,
639    /// or if `block_size` is zero.
640    pub fn new(left: L, right: R, shape: Vec<u64>, block_size: usize) -> Self {
641        assert!(block_size > 0);
642
643        Self {
644            left: left.fuse(),
645            right: right.fuse(),
646
647            shape,
648            block_size,
649
650            pending_left: None,
651            pending_right: None,
652            buffer: None,
653        }
654    }
655}
656
657impl<E, L, R> Stream for CoordMerge<L, R>
658where
659    L: Stream<Item = Result<Coords, E>> + Unpin,
660    R: Stream<Item = Result<Coords, E>> + Unpin,
661{
662    type Item = Result<Coords, E>;
663
664    fn poll_next(self: Pin<&mut Self>, cxt: &mut Context<'_>) -> Poll<Option<Self::Item>> {
665        let mut this = self.project();
666
667        Poll::Ready(loop {
668            if this.pending_left.is_none() && !this.left.is_terminated() {
669                match ready!(this.left.as_mut().poll_next(cxt)) {
670                    Some(Ok(coords)) => {
671                        assert_eq!(coords.ndim(), this.shape.len());
672                        *this.pending_left = Some(coords)
673                    }
674                    Some(Err(cause)) => return Poll::Ready(Some(Err(cause))),
675                    None => {}
676                }
677            }
678
679            if this.pending_right.is_none() && !this.right.is_terminated() {
680                match ready!(this.right.as_mut().poll_next(cxt)) {
681                    Some(Ok(coords)) => {
682                        assert_eq!(coords.ndim(), this.shape.len());
683                        *this.pending_right = Some(coords)
684                    }
685                    Some(Err(cause)) => return Poll::Ready(Some(Err(cause))),
686                    None => {}
687                }
688            }
689
690            match (&mut *this.pending_left, &mut *this.pending_right) {
691                (Some(l), Some(r)) if l.last() < r.last() => {
692                    let (r, r_pending) = r.split_lte(&l.last(), this.shape);
693                    *this.pending_right = r_pending;
694
695                    if let Some(r) = r {
696                        create_or_append(this.buffer, r);
697                    }
698
699                    let mut l = None;
700                    mem::swap(this.pending_left, &mut l);
701                    create_or_append(this.buffer, l.unwrap());
702                }
703                (Some(l), Some(r)) if r.last() < l.last() => {
704                    let (l, l_pending) = l.split_lte(&r.last(), this.shape);
705                    *this.pending_left = l_pending;
706
707                    if let Some(l) = l {
708                        create_or_append(this.buffer, l);
709                    }
710
711                    let mut r = None;
712                    mem::swap(this.pending_right, &mut r);
713                    create_or_append(this.buffer, r.unwrap());
714                }
715                (Some(l), Some(r)) => {
716                    assert_eq!(l.last(), r.last());
717
718                    let mut l = None;
719                    mem::swap(this.pending_left, &mut l);
720                    create_or_append(this.buffer, l.unwrap());
721
722                    let mut r = None;
723                    mem::swap(this.pending_right, &mut r);
724                    create_or_append(this.buffer, r.unwrap());
725                }
726                (Some(_), None) => {
727                    let mut new_l = None;
728                    mem::swap(this.pending_left, &mut new_l);
729                    create_or_append(this.buffer, new_l.unwrap());
730                }
731                (_, Some(_)) => {
732                    let mut new_r = None;
733                    mem::swap(this.pending_right, &mut new_r);
734                    create_or_append(this.buffer, new_r.unwrap());
735                }
736                (None, None) if this.buffer.is_some() => {
737                    let coords = this.buffer.as_ref().unwrap().sorted();
738                    *this.buffer = None;
739                    break Some(Ok(coords));
740                }
741                (None, None) => break None,
742            }
743
744            if let Some(buffer) = this.buffer {
745                if buffer.len() == *this.block_size {
746                    let mut coords = None;
747                    mem::swap(&mut coords, this.buffer);
748                    break Some(Ok(coords.unwrap().sorted()));
749                } else if buffer.len() > *this.block_size {
750                    let coords = buffer.sorted();
751                    let (coords, buffer) = coords.split(*this.block_size);
752                    *this.buffer = Some(buffer);
753                    break Some(Ok(coords));
754                }
755            }
756        })
757    }
758}
759
760/// Return only the unique coordinates from a sorted stream of `Coords`.
761///
762/// Behavior is undefined if the input stream is not sorted.
763#[pin_project]
764pub struct CoordUnique<S> {
765    #[pin]
766    source: Fuse<S>,
767    buffer: Option<Coords>,
768    shape: Vec<u64>,
769    block_size: usize,
770}
771
772impl<S: Stream> CoordUnique<S> {
773    /// Construct a new `CoordUnique` stream from a sorted stream of `Coords`.
774    pub fn new(source: S, shape: Vec<u64>, block_size: usize) -> Self {
775        Self {
776            source: source.fuse(),
777            buffer: None,
778            shape,
779            block_size,
780        }
781    }
782}
783
784impl<E, S: Stream<Item = Result<Coords, E>>> Stream for CoordUnique<S> {
785    type Item = Result<Coords, E>;
786
787    fn poll_next(self: Pin<&mut Self>, cxt: &mut Context<'_>) -> Poll<Option<Self::Item>> {
788        let mut this = self.project();
789
790        Poll::Ready(loop {
791            match ready!(this.source.as_mut().poll_next(cxt)) {
792                Some(Ok(block)) => {
793                    let buffer = if let Some(buffer) = this.buffer {
794                        buffer.append(&block).unique(this.shape)
795                    } else {
796                        block.unique(this.shape)
797                    };
798
799                    *this.buffer = Some(buffer);
800                }
801                Some(Err(cause)) => break Some(Err(cause)),
802                None if this.buffer.is_some() => {
803                    let mut buffer = None;
804                    mem::swap(this.buffer, &mut buffer);
805                    break buffer.map(Ok);
806                }
807                None => break None,
808            }
809
810            if let Some(buffer) = this.buffer {
811                if buffer.len() > *this.block_size {
812                    let (block, buffer) = buffer.split(*this.block_size);
813                    *this.buffer = Some(buffer);
814                    break Some(Ok(block));
815                }
816            }
817        })
818    }
819}
820
821#[inline]
822fn create_or_append(coords: &mut Option<Coords>, to_append: Coords) {
823    if to_append.is_empty() {
824        return;
825    }
826
827    assert!(to_append.dims()[0] > 0);
828
829    *coords = match coords {
830        Some(coords) => Some(coords.append(&to_append)),
831        None => Some(to_append),
832    };
833}
834
/// Convert a coordinate to a linear offset.
///
/// The offset is the dot product of `coord` with `coord_bounds`, the per-axis multipliers
/// (presumably the row-major strides produced by `coord_bounds(shape)`). If the slices differ
/// in length, the extra trailing elements of the longer one are ignored.
#[inline]
pub fn coord_to_offset(coord: &[u64], coord_bounds: &[u64]) -> u64 {
    coord
        .iter()
        .zip(coord_bounds)
        .fold(0, |offset, (x, stride)| offset + x * stride)
}
844
845fn index_get(subject: &af::Array<u64>, index: &[u64]) -> af::Array<u64> {
846    let len = subject.dims()[1];
847    let index = af::Array::new(index, af::dim4!(index.len() as u64));
848    let seq4gen = af::seq!(0, (len - 1) as i32, 1);
849    let mut indexer = af::Indexer::default();
850    indexer.set_index(&index, 0, None);
851    indexer.set_index(&seq4gen, 1, Some(true));
852
853    af::index_gen(subject, indexer)
854}
855
/// Write `value` into the rows listed in `index` (i.e. coordinate axes) of every column of
/// `subject`.
///
/// `value` should have one row per entry in `index` and the same number of columns as
/// `subject`; both are checked only by `debug_assert!`.
fn index_set(subject: &mut af::Array<u64>, index: &[u64], value: &af::Array<u64>) {
    debug_assert!(value.dims()[0] == index.len() as u64);
    debug_assert!(value.dims()[1] == subject.dims()[1]);

    let len = subject.dims()[1];
    let index = af::Array::new(index, af::dim4!(index.len() as u64));
    if len == 1 {
        // A single column: index only dimension 0 (the axes), with batching disabled.
        // NOTE(review): presumably the batched column sequence below does not work for a
        // one-column array, hence this special case — confirm.
        let mut indexer = af::Indexer::default();
        indexer.set_index(&index, 0, Some(false));
        af::assign_gen(subject, &indexer, value);
    } else {
        // Index the given axes in dimension 0 and every column in dimension 1.
        // NOTE: `len == 0` would underflow `len - 1` here.
        let seq4gen = af::seq!(0, (len - 1) as i32, 1);
        let mut indexer = af::Indexer::default();
        indexer.set_index(&index, 0, None);
        indexer.set_index(&seq4gen, 1, Some(true));

        af::assign_gen(subject, &indexer, value);
    }
}
875
/// Unit tests for `Coords` conversions and the private merge/unique helpers.
#[cfg(test)]
mod tests {
    use super::*;

    /// Offsets `0..5` in shape `[5, 2]` map to coordinates in row-major order.
    #[test]
    fn test_to_coords() {
        let offsets = ArrayExt::range(0, 5);
        let coords = Coords::from_offsets(offsets, &[5, 2]);
        assert_eq!(
            coords.into_vec(),
            vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1], vec![2, 0],]
        )
    }

    /// `last`, `split`, `split_lte`, `append` and `sorted` on already-sorted coordinates.
    #[test]
    fn test_merge_helpers() {
        let coord_vec = vec![
            vec![0, 0, 0],
            vec![0, 0, 1],
            vec![0, 1, 0],
            vec![1, 0, 0],
            vec![1, 1, 1],
        ];
        let coords = Coords::from_iter(coord_vec.to_vec(), 3);

        assert_eq!(&coords.last(), coord_vec.last().unwrap());

        let (l, r) = coords.split(1);
        assert_eq!(l.to_vec(), &coord_vec[..1]);
        assert_eq!(r.to_vec(), &coord_vec[1..]);

        // `[0, 1, 0]` is inclusive, so the left half has three coordinates
        let (l, r) = coords.split_lte(&[0, 1, 0], &[2, 2, 2]);
        assert_eq!(l.as_ref().expect("left").to_vec(), &coord_vec[..3]);
        assert_eq!(r.as_ref().expect("right").to_vec(), &coord_vec[3..]);

        let joined = l.expect("left").append(r.as_ref().expect("right"));
        assert_eq!(joined.to_vec(), coords.to_vec());

        assert_eq!(coords.to_vec(), coords.sorted().to_vec());
    }

    /// `unique` drops the duplicated `[0, 0, 1]` from sorted input.
    #[test]
    fn test_unique_helpers() {
        let coord_vec = vec![
            vec![0, 0, 0],
            vec![0, 0, 1],
            vec![0, 0, 1],
            vec![0, 1, 0],
            vec![1, 0, 0],
        ];

        let coords = Coords::from_iter(coord_vec.to_vec(), 3);

        let expected = vec![vec![0, 0, 0], vec![0, 0, 1], vec![0, 1, 0], vec![1, 0, 0]];
        assert_eq!(coords.unique(&[2, 2, 2]).to_vec(), expected);
    }

    /// `get` selects axes 1 and 2; `set` writes them into axes 0 and 2 of zeroed coordinates.
    #[test]
    fn test_get_and_set() {
        let source = Coords::from_iter(vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]], 3);

        let value = source.get(&[1, 2]);

        assert_eq!(value.ndim(), 2);
        assert_eq!(value.to_vec(), vec![vec![1, 2], vec![4, 5], vec![7, 8]]);

        let mut dest = Coords::empty(&[10, 15, 20], 3);
        dest.set(&[0, 2], &value);

        assert_eq!(dest.to_vec(), vec![[1, 0, 2], [4, 0, 5], [7, 0, 8],])
    }

    /// Six broadcast axes map onto a four-dimensional source. Axes 2 and 5 were not broadcast
    /// and map to source axes 0 and 3; every other source axis is zero.
    #[test]
    fn test_unbroadcast() {
        let coords = Coords::from_iter(vec![vec![8, 15, 2, 1, 10, 3], vec![9, 16, 3, 4, 11, 6]], 6);
        let actual = coords.unbroadcast(&[5, 1, 1, 10], &[true, true, false, true, true, false]);
        assert_eq!(actual.to_vec(), vec![vec![2, 0, 0, 3], vec![3, 0, 0, 6]]);
    }

    /// `expand` inserts every position of the reduced axis, for each reduce axis in turn.
    #[test]
    fn test_reduce() {
        let coords = Coords::from_iter(vec![vec![0, 1], vec![1, 2]], 2);
        let actual = coords.expand(&[2, 3, 5], 0);
        assert_eq!(
            actual.to_vec(),
            vec![vec![0, 0, 1], vec![1, 0, 1], vec![0, 1, 2], vec![1, 1, 2],]
        );

        let coords = Coords::from_iter(vec![vec![0, 1], vec![1, 2]], 2);
        let actual = coords.expand(&[3, 2, 5], 1);
        assert_eq!(
            actual.to_vec(),
            vec![vec![0, 0, 1], vec![0, 1, 1], vec![1, 0, 2], vec![1, 1, 2],]
        );

        let coords = Coords::from_iter(vec![vec![0, 1], vec![1, 2]], 2);
        let actual = coords.expand(&[3, 5, 2], 2);
        assert_eq!(
            actual.to_vec(),
            vec![vec![0, 1, 0], vec![0, 1, 1], vec![1, 2, 0], vec![1, 2, 1],]
        );
    }
}