1use std::collections::HashMap;
2use std::fmt;
3use std::iter::IntoIterator;
4use std::mem;
5use std::ops::{Deref, DerefMut};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use arrayfire as af;
10use futures::ready;
11use futures::stream::{Fuse, FusedStream, Stream, StreamExt, TryStream, TryStreamExt};
12use pin_project::pin_project;
13
14use super::{coord_bounds, ArrayExt};
15
/// A single coordinate: one `u64` index per axis.
pub type Coord = Vec<u64>;
18
/// Flat offsets, one per coordinate, as produced by `Coords::to_offsets`.
pub type Offsets = ArrayExt<u64>;
21
#[derive(Clone)]
/// A block of coordinates backed by a single ArrayFire array.
///
/// The array has dimensions `[ndim, num_coords, 1, 1]`: each column is one coordinate
/// and each row is one axis.
pub struct Coords {
    // `[ndim, num_coords]`; column `i` holds the `i`-th coordinate
    array: af::Array<u64>,
    // number of axes per coordinate (the size of the array's first dimension)
    ndim: usize,
}
30
31impl Coords {
32 pub fn empty(shape: &[u64], size: usize) -> Self {
36 assert!(!shape.is_empty());
37 assert!(size > 0);
38
39 let ndim = shape.len();
40 let dims = af::Dim4::new(&[ndim as u64, size as u64, 1, 1]);
41 let array = af::constant(0u64, dims);
42 assert_eq!(array.dims(), dims);
43 Self { array, ndim }
44 }
45
46 pub fn from_iter<I: IntoIterator<Item = Coord>>(iter: I, ndim: usize) -> Self {
50 assert!(ndim > 0);
51
52 let buffer: Vec<u64> = iter
53 .into_iter()
54 .inspect(|coord| assert_eq!(coord.len(), ndim))
55 .flatten()
56 .collect();
57
58 let num_coords = buffer.len() / ndim;
59 let dims = af::Dim4::new(&[ndim as u64, num_coords as u64, 1, 1]);
60 let array = af::Array::new(&buffer, dims);
61 Self { array, ndim }
62 }
63
64 pub fn from_offsets(offsets: Offsets, shape: &[u64]) -> Self {
68 assert!(!shape.is_empty());
69
70 let ndim = shape.len() as u64;
71 let coord_bounds = coord_bounds(shape);
72
73 let dims = af::Dim4::new(&[1, ndim, 1, 1]);
74 let af_coord_bounds: af::Array<u64> = af::Array::new(&coord_bounds, dims);
75 let af_shape: af::Array<u64> = af::Array::new(&shape, dims);
76
77 let offsets = af::div(offsets.deref(), &af_coord_bounds, true);
78 let coords = af::modulo(&offsets, &af_shape, true);
79 let array = af::transpose(&coords, false);
80
81 Self {
82 array,
83 ndim: shape.len(),
84 }
85 }
86
87 pub async fn from_stream<S: Stream<Item = Coord> + Unpin>(
91 mut source: S,
92 ndim: usize,
93 size_hint: Option<usize>,
94 ) -> Self {
95 assert!(ndim > 0);
96
97 let mut num_coords = 0;
98 let mut buffer = if let Some(size) = size_hint {
99 Vec::with_capacity(size)
100 } else {
101 Vec::new()
102 };
103
104 while let Some(coord) = source.next().await {
105 assert_eq!(coord.len(), ndim);
106 buffer.extend(coord);
107 num_coords += 1;
108 }
109
110 let array = af::Array::new(&buffer, af::Dim4::new(&[ndim as u64, num_coords, 1, 1]));
111
112 Self { array, ndim }
113 }
114
115 pub async fn try_from_stream<E, S: TryStream<Ok = Coord, Error = E> + Unpin>(
119 mut source: S,
120 ndim: usize,
121 size_hint: Option<usize>,
122 ) -> Result<Self, E> {
123 let mut num_coords = 0;
124 let mut buffer = if let Some(size) = size_hint {
125 Vec::with_capacity(size)
126 } else {
127 Vec::new()
128 };
129
130 while let Some(coord) = source.try_next().await? {
131 assert_eq!(coord.len(), ndim);
132 buffer.extend(coord);
133 num_coords += 1;
134 }
135
136 let array = af::Array::new(&buffer, af::Dim4::new(&[ndim as u64, num_coords, 1, 1]));
137
138 Ok(Self { array, ndim })
139 }
140
141 pub fn is_empty(&self) -> bool {
143 self.array.elements() == 0
144 }
145
146 pub fn is_sorted(&self, shape: &[u64]) -> bool {
148 self.to_offsets(shape).is_sorted()
149 }
150
151 pub fn len(&self) -> usize {
153 self.dims()[1] as usize
154 }
155
156 pub fn ndim(&self) -> usize {
158 self.ndim
159 }
160
161 fn last(&self) -> Coord {
162 let i = (self.len() - 1) as i32;
163 let dim0 = af::seq!(0, (self.ndim - 1) as i32, 1);
164 let dim1 = af::seq!(i, i, 1);
165 let slice = af::index(self, &[dim0, dim1]);
166 let mut first = vec![0; self.ndim];
167 slice.host(&mut first);
168 first
169 }
170
171 fn append(&self, other: &Coords) -> Self {
172 assert_eq!(self.ndim, other.ndim);
173
174 let array = af::join(1, self, other);
175 Self {
176 array,
177 ndim: self.ndim,
178 }
179 }
180
181 fn split(&self, at: usize) -> (Self, Self) {
182 assert!(at > 0);
183 assert!(at < self.len());
184
185 let left = af::seq!(0, (at - 1) as i32, 1);
186 let right = af::seq!(at as i32, (self.len() - 1) as i32, 1);
187
188 let left = af::index(self, &[af::seq!(), left]);
189 let right = af::index(self, &[af::seq!(), right]);
190
191 (
192 Self {
193 array: left,
194 ndim: self.ndim,
195 },
196 Self {
197 array: right,
198 ndim: self.ndim,
199 },
200 )
201 }
202
203 fn split_lte(&self, lt: &[u64], shape: &[u64]) -> (Option<Self>, Option<Self>) {
204 assert_eq!(lt.len(), self.ndim);
205 assert_eq!(shape.len(), self.ndim);
206
207 let coord_bounds = coord_bounds(shape);
208 let pivot = coord_to_offset(lt, &coord_bounds);
209 let pivot = af::Array::new(&[pivot], af::dim4!(1));
210 let offsets = self.to_offsets(shape);
211 let left = af::le(offsets.deref(), &pivot, true);
212 let pivot = af::sum_all(&left).0;
213
214 if pivot == 0 {
215 return (None, Some(self.clone()));
216 } else if pivot == self.len() as u32 {
217 return (Some(self.clone()), None);
218 }
219
220 let (l, r) = self.split(pivot as usize);
221
222 debug_assert_eq!(l.array.dims()[0], self.ndim as u64);
223 debug_assert_eq!(r.array.dims()[0], self.ndim as u64);
224
225 (Some(l), Some(r))
226 }
227
228 fn sorted(&self) -> Self {
229 let array = af::sort(self, 2, true);
230 Self {
231 array,
232 ndim: self.ndim,
233 }
234 }
235
236 fn unique(&self, shape: &[u64]) -> Self {
237 let offsets = self.to_offsets(shape);
238 let offsets = af::set_unique(offsets.deref(), true);
239 Self::from_offsets(offsets.into(), shape)
240 }
241
242 pub fn contract_dim(&self, axis: usize) -> Self {
246 assert!(axis < self.ndim);
247
248 let mut index: Vec<usize> = (0..self.ndim).collect();
249 index.remove(axis);
250
251 self.get(&index)
252 }
253
254 pub fn expand(&self, source_shape: &[u64], reduce_axis: usize) -> Self {
258 let ndim = self.ndim + 1;
259
260 assert_eq!(source_shape.len(), ndim);
261 assert!(reduce_axis <= ndim);
262
263 let reduce_dim = source_shape[reduce_axis];
264
265 let dims = af::dim4!(1, reduce_dim);
266 let reduced = af::range(dims, 1);
267
268 let reduce_index = vec![reduce_axis as u64];
269
270 let index: Vec<u64> = (0..self.ndim)
271 .map(|x| if x < reduce_axis { x } else { x + 1 })
272 .map(|x| x as u64)
273 .collect();
274
275 let tile_dims = af::dim4!(1, reduce_dim);
276 let source_coord_dims = af::dim4!(ndim as u64, reduce_dim);
277
278 let mut expanded = Vec::with_capacity(self.len());
279 for i in 0..self.dims()[1] {
280 let i = i as i32;
281 let seqs = &[af::seq!(), af::seq!(i, i, 1)];
282 let coord = af::index(&self.array, seqs);
283 let coord = af::tile(&coord, tile_dims);
284
285 let mut expanded_coord = af::constant(0, source_coord_dims);
286 index_set(&mut expanded_coord, &index, &coord);
287 index_set(&mut expanded_coord, &reduce_index, &reduced);
288 expanded.push(expanded_coord);
289 }
290
291 Self {
292 array: af::join_many(1, expanded.iter().collect()),
293 ndim,
294 }
295 }
296
297 pub fn expand_dim(&self, axis: usize) -> Self {
301 assert!(axis <= self.ndim);
302
303 let ndim = self.ndim + 1;
304 let dims = af::Dim4::new(&[ndim as u64, self.dims()[1], 1, 1]);
305 let mut expanded = af::constant(0, dims);
306
307 let index: Vec<u64> = (0..self.ndim())
308 .map(|x| if x < axis { x } else { x + 1 })
309 .map(|x| x as u64)
310 .collect();
311
312 index_set(&mut expanded, &index, self);
313
314 Self {
315 array: expanded,
316 ndim,
317 }
318 }
319
320 pub fn flip(self, shape: &[u64], axis: usize) -> Self {
326 assert_eq!(self.ndim, shape.len());
327
328 let mut mask = vec![0i64; self.ndim()];
329 mask[axis] = (shape[axis] - 1) as i64;
330 let mask = af::Array::new(&mask, af::Dim4::new(&[self.ndim() as u64, 1, 1, 1]));
331
332 let coords: af::Array<i64> = self.array.cast();
333 let flipped = af::sub(&mask, &coords, true);
334
335 Self {
336 array: af::abs(&flipped).cast(),
337 ndim: self.ndim,
338 }
339 }
340
341 pub fn slice(
343 &self,
344 shape: &[u64],
345 elided: &HashMap<usize, u64>,
346 offset: &HashMap<usize, u64>,
347 ) -> Self {
348 let ndim = shape.len();
349 let mut offsets = Vec::with_capacity(ndim);
350 let mut index = Vec::with_capacity(ndim);
351 for x in 0..self.ndim {
352 if elided.contains_key(&x) {
353 continue;
354 }
355
356 let offset = offset.get(&x).unwrap_or(&0);
357 offsets.push(*offset);
358 index.push(x);
359 }
360
361 let offsets = af::Array::new(&offsets, af::dim4!(offsets.len() as u64));
362 let array = af::sub(self.get(&index).deref(), &offsets, true);
363 Self { array, ndim }
364 }
365
366 pub fn transpose<P: AsRef<[usize]>>(&self, permutation: Option<P>) -> Coords {
370 if let Some(permutation) = permutation {
371 self.get(permutation.as_ref())
372 } else {
373 let array = af::transpose(&self.array, false);
374 let ndim = self.ndim;
375 Self { array, ndim }
376 }
377 }
378
379 pub fn unbroadcast(&self, source_shape: &[u64], broadcast: &[bool]) -> Coords {
383 assert_eq!(self.ndim(), broadcast.len());
384
385 let offset = self.ndim() - source_shape.len();
386 let mut coords = Self::empty(source_shape, self.len());
387 if source_shape.is_empty() || broadcast.iter().all(|b| *b) {
388 return coords;
389 }
390
391 let axes: Vec<usize> = broadcast
392 .iter()
393 .enumerate()
394 .filter_map(|(x, b)| if *b { None } else { Some(x) })
395 .collect();
396
397 let unbroadcasted = self.get(&axes);
398
399 let axes: Vec<usize> = broadcast
400 .iter()
401 .enumerate()
402 .filter_map(|(x, b)| if *b { None } else { Some(x - offset) })
403 .collect();
404
405 coords.set(&axes, &unbroadcasted);
406
407 coords
408 }
409
410 pub fn unslice(
414 &self,
415 source_shape: &[u64],
416 elided: &HashMap<usize, u64>,
417 offset: &HashMap<usize, u64>,
418 ) -> Self {
419 let ndim = source_shape.len();
420 let mut axes = Vec::with_capacity(self.ndim);
421 let mut unsliced = vec![0; source_shape.len()];
422 let mut offsets = vec![0; source_shape.len()];
423 for x in 0..ndim {
424 if let Some(elide) = elided.get(&x) {
425 unsliced[x] = *elide;
426 } else {
427 axes.push(x as u64);
428 offsets[x] = *offset.get(&x).unwrap_or(&0);
429 }
430 }
431 assert_eq!(axes.len(), self.ndim);
432
433 let unsliced = af::Array::new(&unsliced, af::dim4!(ndim as u64));
434 let tile_dims = af::Dim4::new(&[1, self.len() as u64, 1, 1]);
435 let mut unsliced = af::tile(&unsliced, tile_dims);
436 index_set(&mut unsliced, &axes, self);
437
438 let offsets = af::Array::new(&offsets, af::dim4!(ndim as u64));
439 let offsets = af::tile(&offsets, tile_dims);
440
441 Self {
442 array: unsliced + offsets,
443 ndim,
444 }
445 }
446
447 pub fn get(&self, axes: &[usize]) -> Self {
451 let axes: Vec<u64> = axes
452 .iter()
453 .map(|x| {
454 assert!(x < &self.ndim);
455 *x as u64
456 })
457 .collect();
458
459 let array = index_get(self, &axes);
460 Self {
461 array,
462 ndim: axes.len(),
463 }
464 }
465
466 pub fn set(&mut self, axes: &[usize], value: &Self) {
470 let axes: Vec<u64> = axes
471 .iter()
472 .map(|x| {
473 assert!(x < &self.ndim);
474 *x as u64
475 })
476 .collect();
477
478 index_set(self, &axes, value)
479 }
480
481 pub fn to_offsets(&self, shape: &[u64]) -> ArrayExt<u64> {
485 let ndim = shape.len();
486 assert_eq!(self.ndim, ndim);
487
488 let coord_bounds = coord_bounds(shape);
489 let af_coord_bounds: af::Array<u64> = af::Array::new(&coord_bounds, af::dim4!(ndim as u64));
490
491 let offsets = af::mul(&self.array, &af_coord_bounds, true);
492 let offsets = af::sum(&offsets, 0).into();
493 af::moddims(&offsets, af::dim4!(offsets.elements() as u64)).into()
494 }
495
496 pub fn to_vec(&self) -> Vec<Coord> {
500 assert_eq!(self.array.elements() % self.ndim, 0);
501
502 let mut to_vec = vec![0u64; self.array.elements()];
503 self.array.host(&mut to_vec);
504
505 to_vec
506 .chunks(self.ndim)
507 .map(|coord| coord.to_vec())
508 .collect()
509 }
510
511 pub fn into_vec(self) -> Vec<Coord> {
515 self.to_vec()
516 }
517}
518
519impl Deref for Coords {
520 type Target = af::Array<u64>;
521
522 fn deref(&self) -> &Self::Target {
523 &self.array
524 }
525}
526
527impl DerefMut for Coords {
528 fn deref_mut(&mut self) -> &mut Self::Target {
529 &mut self.array
530 }
531}
532
533impl PartialEq for Coords {
534 fn eq(&self, other: &Self) -> bool {
535 if self.ndim == other.ndim {
536 let batch = self.array.dims() != other.array.dims();
537 af::all_true_all(&af::eq(&self.array, &other.array, batch)).0
538 } else {
539 false
540 }
541 }
542}
543
544impl fmt::Debug for Coords {
545 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
546 write!(f, "a block of {} coordinates", self.len())
547 }
548}
549
/// Groups a stream of individual [`Coord`]s into [`Coords`] blocks.
pub struct CoordBlocks<S> {
    // the coordinate source, fused so polling it after it ends is safe
    source: Fuse<S>,
    // number of axes per coordinate
    ndim: usize,
    // number of coordinates per emitted block
    block_size: usize,
    // flattened coordinates accumulated for the next block
    buffer: Vec<u64>,
}
557
558impl<E, S: Stream<Item = Result<Coord, E>>> CoordBlocks<S> {
559 pub fn new(source: S, ndim: usize, block_size: usize) -> Self {
563 assert!(ndim > 0);
564
565 Self {
566 source: source.fuse(),
567 ndim,
568 block_size,
569 buffer: Vec::with_capacity(ndim * block_size),
570 }
571 }
572
573 fn consume_buffer(&mut self) -> Coords {
574 assert_eq!(self.buffer.len() % self.ndim, 0);
575
576 let ndim = self.ndim as u64;
577 let num_coords = (self.buffer.len() / self.ndim) as u64;
578 let dims = af::Dim4::new(&[ndim, num_coords, 1, 1]);
579 let coords = Coords {
580 array: af::Array::new(&self.buffer, dims),
581 ndim: self.ndim,
582 };
583
584 self.buffer.clear();
585 coords
586 }
587}
588
589impl<E, S: Stream<Item = Result<Coord, E>> + Unpin> Stream for CoordBlocks<S> {
590 type Item = Result<Coords, E>;
591
592 fn poll_next(mut self: Pin<&mut Self>, cxt: &mut Context<'_>) -> Poll<Option<Self::Item>> {
593 Poll::Ready(loop {
594 match ready!(Pin::new(&mut self.source).poll_next(cxt)) {
595 Some(Ok(coord)) => {
596 assert_eq!(coord.len(), self.ndim);
597 self.buffer.extend(coord);
598
599 if self.buffer.len() == (self.block_size * self.ndim) {
600 break Some(Ok(self.consume_buffer()));
601 }
602 }
603 Some(Err(cause)) => break Some(Err(cause)),
604 None if self.buffer.is_empty() => break None,
605 None => break Some(Ok(self.consume_buffer())),
606 }
607 })
608 }
609}
610
611impl<E, S: Stream<Item = Result<Coord, E>> + Unpin> FusedStream for CoordBlocks<S> {
612 fn is_terminated(&self) -> bool {
613 self.source.is_terminated() && self.buffer.is_empty()
614 }
615}
616
#[pin_project]
/// Merges two streams of coordinate blocks into one stream of sorted blocks.
pub struct CoordMerge<L, R> {
    // left input stream
    #[pin]
    left: Fuse<L>,

    // right input stream
    #[pin]
    right: Fuse<R>,

    // a block from `left` that has not been fully merged yet
    pending_left: Option<Coords>,
    // a block from `right` that has not been fully merged yet
    pending_right: Option<Coords>,
    // merged coordinates waiting to be sorted and emitted
    buffer: Option<Coords>,
    // target number of coordinates per emitted block
    block_size: usize,
    // shape of the coordinate space, used to compute offsets when splitting
    shape: Vec<u64>,
}
634
635impl<L: Stream, R: Stream> CoordMerge<L, R> {
636 pub fn new(left: L, right: R, shape: Vec<u64>, block_size: usize) -> Self {
641 assert!(block_size > 0);
642
643 Self {
644 left: left.fuse(),
645 right: right.fuse(),
646
647 shape,
648 block_size,
649
650 pending_left: None,
651 pending_right: None,
652 buffer: None,
653 }
654 }
655}
656
657impl<E, L, R> Stream for CoordMerge<L, R>
658where
659 L: Stream<Item = Result<Coords, E>> + Unpin,
660 R: Stream<Item = Result<Coords, E>> + Unpin,
661{
662 type Item = Result<Coords, E>;
663
664 fn poll_next(self: Pin<&mut Self>, cxt: &mut Context<'_>) -> Poll<Option<Self::Item>> {
665 let mut this = self.project();
666
667 Poll::Ready(loop {
668 if this.pending_left.is_none() && !this.left.is_terminated() {
669 match ready!(this.left.as_mut().poll_next(cxt)) {
670 Some(Ok(coords)) => {
671 assert_eq!(coords.ndim(), this.shape.len());
672 *this.pending_left = Some(coords)
673 }
674 Some(Err(cause)) => return Poll::Ready(Some(Err(cause))),
675 None => {}
676 }
677 }
678
679 if this.pending_right.is_none() && !this.right.is_terminated() {
680 match ready!(this.right.as_mut().poll_next(cxt)) {
681 Some(Ok(coords)) => {
682 assert_eq!(coords.ndim(), this.shape.len());
683 *this.pending_right = Some(coords)
684 }
685 Some(Err(cause)) => return Poll::Ready(Some(Err(cause))),
686 None => {}
687 }
688 }
689
690 match (&mut *this.pending_left, &mut *this.pending_right) {
691 (Some(l), Some(r)) if l.last() < r.last() => {
692 let (r, r_pending) = r.split_lte(&l.last(), this.shape);
693 *this.pending_right = r_pending;
694
695 if let Some(r) = r {
696 create_or_append(this.buffer, r);
697 }
698
699 let mut l = None;
700 mem::swap(this.pending_left, &mut l);
701 create_or_append(this.buffer, l.unwrap());
702 }
703 (Some(l), Some(r)) if r.last() < l.last() => {
704 let (l, l_pending) = l.split_lte(&r.last(), this.shape);
705 *this.pending_left = l_pending;
706
707 if let Some(l) = l {
708 create_or_append(this.buffer, l);
709 }
710
711 let mut r = None;
712 mem::swap(this.pending_right, &mut r);
713 create_or_append(this.buffer, r.unwrap());
714 }
715 (Some(l), Some(r)) => {
716 assert_eq!(l.last(), r.last());
717
718 let mut l = None;
719 mem::swap(this.pending_left, &mut l);
720 create_or_append(this.buffer, l.unwrap());
721
722 let mut r = None;
723 mem::swap(this.pending_right, &mut r);
724 create_or_append(this.buffer, r.unwrap());
725 }
726 (Some(_), None) => {
727 let mut new_l = None;
728 mem::swap(this.pending_left, &mut new_l);
729 create_or_append(this.buffer, new_l.unwrap());
730 }
731 (_, Some(_)) => {
732 let mut new_r = None;
733 mem::swap(this.pending_right, &mut new_r);
734 create_or_append(this.buffer, new_r.unwrap());
735 }
736 (None, None) if this.buffer.is_some() => {
737 let coords = this.buffer.as_ref().unwrap().sorted();
738 *this.buffer = None;
739 break Some(Ok(coords));
740 }
741 (None, None) => break None,
742 }
743
744 if let Some(buffer) = this.buffer {
745 if buffer.len() == *this.block_size {
746 let mut coords = None;
747 mem::swap(&mut coords, this.buffer);
748 break Some(Ok(coords.unwrap().sorted()));
749 } else if buffer.len() > *this.block_size {
750 let coords = buffer.sorted();
751 let (coords, buffer) = coords.split(*this.block_size);
752 *this.buffer = Some(buffer);
753 break Some(Ok(coords));
754 }
755 }
756 })
757 }
758}
759
#[pin_project]
/// Removes duplicate coordinates from a stream of coordinate blocks.
pub struct CoordUnique<S> {
    // input stream of coordinate blocks
    #[pin]
    source: Fuse<S>,
    // de-duplicated coordinates not yet emitted
    buffer: Option<Coords>,
    // shape of the coordinate space, used to compute offsets for `unique`
    shape: Vec<u64>,
    // target number of coordinates per emitted block
    block_size: usize,
}
771
772impl<S: Stream> CoordUnique<S> {
773 pub fn new(source: S, shape: Vec<u64>, block_size: usize) -> Self {
775 Self {
776 source: source.fuse(),
777 buffer: None,
778 shape,
779 block_size,
780 }
781 }
782}
783
784impl<E, S: Stream<Item = Result<Coords, E>>> Stream for CoordUnique<S> {
785 type Item = Result<Coords, E>;
786
787 fn poll_next(self: Pin<&mut Self>, cxt: &mut Context<'_>) -> Poll<Option<Self::Item>> {
788 let mut this = self.project();
789
790 Poll::Ready(loop {
791 match ready!(this.source.as_mut().poll_next(cxt)) {
792 Some(Ok(block)) => {
793 let buffer = if let Some(buffer) = this.buffer {
794 buffer.append(&block).unique(this.shape)
795 } else {
796 block.unique(this.shape)
797 };
798
799 *this.buffer = Some(buffer);
800 }
801 Some(Err(cause)) => break Some(Err(cause)),
802 None if this.buffer.is_some() => {
803 let mut buffer = None;
804 mem::swap(this.buffer, &mut buffer);
805 break buffer.map(Ok);
806 }
807 None => break None,
808 }
809
810 if let Some(buffer) = this.buffer {
811 if buffer.len() > *this.block_size {
812 let (block, buffer) = buffer.split(*this.block_size);
813 *this.buffer = Some(buffer);
814 break Some(Ok(block));
815 }
816 }
817 })
818 }
819}
820
821#[inline]
822fn create_or_append(coords: &mut Option<Coords>, to_append: Coords) {
823 if to_append.is_empty() {
824 return;
825 }
826
827 assert!(to_append.dims()[0] > 0);
828
829 *coords = match coords {
830 Some(coords) => Some(coords.append(&to_append)),
831 None => Some(to_append),
832 };
833}
834
/// Flatten `coord` into a single offset.
///
/// The offset is the dot product of `coord` with the per-axis `coord_bounds`. Extra
/// entries in either slice are ignored.
#[inline]
pub fn coord_to_offset(coord: &[u64], coord_bounds: &[u64]) -> u64 {
    coord
        .iter()
        .zip(coord_bounds)
        .map(|(&x, &bound)| x * bound)
        .sum()
}
844
845fn index_get(subject: &af::Array<u64>, index: &[u64]) -> af::Array<u64> {
846 let len = subject.dims()[1];
847 let index = af::Array::new(index, af::dim4!(index.len() as u64));
848 let seq4gen = af::seq!(0, (len - 1) as i32, 1);
849 let mut indexer = af::Indexer::default();
850 indexer.set_index(&index, 0, None);
851 indexer.set_index(&seq4gen, 1, Some(true));
852
853 af::index_gen(subject, indexer)
854}
855
856fn index_set(subject: &mut af::Array<u64>, index: &[u64], value: &af::Array<u64>) {
857 debug_assert!(value.dims()[0] == index.len() as u64);
858 debug_assert!(value.dims()[1] == subject.dims()[1]);
859
860 let len = subject.dims()[1];
861 let index = af::Array::new(index, af::dim4!(index.len() as u64));
862 if len == 1 {
863 let mut indexer = af::Indexer::default();
864 indexer.set_index(&index, 0, Some(false));
865 af::assign_gen(subject, &indexer, value);
866 } else {
867 let seq4gen = af::seq!(0, (len - 1) as i32, 1);
868 let mut indexer = af::Indexer::default();
869 indexer.set_index(&index, 0, None);
870 indexer.set_index(&seq4gen, 1, Some(true));
871
872 af::assign_gen(subject, &indexer, value);
873 }
874}
875
#[cfg(test)]
mod tests {
    use super::*;

    /// Offsets `0..5` in a `[5, 2]` array map to row-major coordinates.
    #[test]
    fn test_to_coords() {
        let expected: Vec<Coord> = vec![vec![0, 0], vec![0, 1], vec![1, 0], vec![1, 1], vec![2, 0]];
        let actual = Coords::from_offsets(ArrayExt::range(0, 5), &[5, 2]).into_vec();
        assert_eq!(actual, expected);
    }

    /// `last`, `split`, `split_lte`, `append` and `sorted` on an already-sorted block.
    #[test]
    fn test_merge_helpers() {
        let expected: Vec<Coord> = vec![
            vec![0, 0, 0],
            vec![0, 0, 1],
            vec![0, 1, 0],
            vec![1, 0, 0],
            vec![1, 1, 1],
        ];
        let coords = Coords::from_iter(expected.clone(), 3);

        assert_eq!(coords.last(), expected[expected.len() - 1]);

        let (head, tail) = coords.split(1);
        assert_eq!(head.to_vec(), &expected[..1]);
        assert_eq!(tail.to_vec(), &expected[1..]);

        let (below, above) = coords.split_lte(&[0, 1, 0], &[2, 2, 2]);
        let below = below.expect("left");
        let above = above.expect("right");
        assert_eq!(below.to_vec(), &expected[..3]);
        assert_eq!(above.to_vec(), &expected[3..]);

        let rejoined = below.append(&above);
        assert_eq!(rejoined.to_vec(), coords.to_vec());

        assert_eq!(coords.sorted().to_vec(), coords.to_vec());
    }

    /// A repeated coordinate is removed by `unique`.
    #[test]
    fn test_unique_helpers() {
        let with_duplicate: Vec<Coord> = vec![
            vec![0, 0, 0],
            vec![0, 0, 1],
            vec![0, 0, 1],
            vec![0, 1, 0],
            vec![1, 0, 0],
        ];
        let coords = Coords::from_iter(with_duplicate, 3);

        let expected: Vec<Coord> = vec![vec![0, 0, 0], vec![0, 0, 1], vec![0, 1, 0], vec![1, 0, 0]];
        assert_eq!(coords.unique(&[2, 2, 2]).to_vec(), expected);
    }

    /// `get` selects axes; `set` writes them into a zeroed block.
    #[test]
    fn test_get_and_set() {
        let source = Coords::from_iter(vec![vec![0, 1, 2], vec![3, 4, 5], vec![6, 7, 8]], 3);

        let value = source.get(&[1, 2]);
        assert_eq!(value.ndim(), 2);

        let expected_value: Vec<Coord> = vec![vec![1, 2], vec![4, 5], vec![7, 8]];
        assert_eq!(value.to_vec(), expected_value);

        let mut dest = Coords::empty(&[10, 15, 20], 3);
        dest.set(&[0, 2], &value);

        let expected_dest: Vec<Coord> = vec![vec![1, 0, 2], vec![4, 0, 5], vec![7, 0, 8]];
        assert_eq!(dest.to_vec(), expected_dest);
    }

    /// Broadcast axes collapse to zero; the others map onto the trailing source axes.
    #[test]
    fn test_unbroadcast() {
        let coords = Coords::from_iter(vec![vec![8, 15, 2, 1, 10, 3], vec![9, 16, 3, 4, 11, 6]], 6);
        let broadcast = [true, true, false, true, true, false];

        let expected: Vec<Coord> = vec![vec![2, 0, 0, 3], vec![3, 0, 0, 6]];
        assert_eq!(coords.unbroadcast(&[5, 1, 1, 10], &broadcast).to_vec(), expected);
    }

    /// `expand` inserts every value of the reduced axis at each possible position.
    #[test]
    fn test_reduce() {
        let coords = Coords::from_iter(vec![vec![0, 1], vec![1, 2]], 2);

        let cases: Vec<([u64; 3], usize, Vec<Coord>)> = vec![
            (
                [2, 3, 5],
                0,
                vec![vec![0, 0, 1], vec![1, 0, 1], vec![0, 1, 2], vec![1, 1, 2]],
            ),
            (
                [3, 2, 5],
                1,
                vec![vec![0, 0, 1], vec![0, 1, 1], vec![1, 0, 2], vec![1, 1, 2]],
            ),
            (
                [3, 5, 2],
                2,
                vec![vec![0, 1, 0], vec![0, 1, 1], vec![1, 2, 0], vec![1, 2, 1]],
            ),
        ];

        for (source_shape, reduce_axis, expected) in cases {
            assert_eq!(coords.expand(&source_shape, reduce_axis).to_vec(), expected);
        }
    }
}