1use ferray_core::error::{FerrayError, FerrayResult};
4use ferray_core::{Array, Dimension, Element, Ix1, IxDyn};
5
6use crate::reductions::{
7 borrow_data, make_result, output_shape, reduce_axis_general_u64, validate_axis,
8};
9
/// Result bundle produced by [`unique`].
///
/// `values` is always populated. Each optional field is `Some` only when the
/// matching `return_*` flag was passed to [`unique`] as `true`.
#[derive(Debug)]
pub struct UniqueResult<T: Element> {
    /// The distinct values found in the input, in the sorted order produced
    /// by [`unique`].
    pub values: Array<T, Ix1>,
    /// For each entry of `values`, the smallest position (in input iteration
    /// order) at which that value occurs.
    pub indices: Option<Array<u64, Ix1>>,
    /// For each input element (in iteration order), the position of its value
    /// inside `values`; has one entry per input element.
    pub inverse: Option<Array<u64, Ix1>>,
    /// For each entry of `values`, how many input elements compared equal to it.
    pub counts: Option<Array<u64, Ix1>>,
}
29
30pub fn unique<T, D>(
41 a: &Array<T, D>,
42 return_index: bool,
43 return_inverse: bool,
44 return_counts: bool,
45) -> FerrayResult<UniqueResult<T>>
46where
47 T: Element + PartialOrd + Copy,
48 D: Dimension,
49{
50 let data: Vec<T> = a.iter().copied().collect();
51 let n_data = data.len();
52
53 let mut pairs: Vec<(T, usize)> = data
55 .iter()
56 .copied()
57 .enumerate()
58 .map(|(i, v)| (v, i))
59 .collect();
60 pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
61
62 let mut unique_vals = Vec::new();
69 let mut unique_indices: Vec<u64> = Vec::new();
70 let mut unique_counts: Vec<u64> = Vec::new();
71 let mut inverse_vec: Vec<u64> = if return_inverse {
72 vec![0u64; n_data]
73 } else {
74 Vec::new()
75 };
76
77 if !pairs.is_empty() {
78 unique_vals.push(pairs[0].0);
79 unique_indices.push(pairs[0].1 as u64);
80 if return_inverse {
81 inverse_vec[pairs[0].1] = 0;
82 }
83 let mut count = 1u64;
84 let mut unique_pos: u64 = 0;
85
86 for i in 1..pairs.len() {
87 if pairs[i].0.partial_cmp(&pairs[i - 1].0) == Some(std::cmp::Ordering::Equal) {
88 count += 1;
89 let last = unique_indices.len() - 1;
91 let new_idx = pairs[i].1 as u64;
92 if new_idx < unique_indices[last] {
93 unique_indices[last] = new_idx;
94 }
95 } else {
96 if return_counts {
97 unique_counts.push(count);
98 }
99 unique_vals.push(pairs[i].0);
100 unique_indices.push(pairs[i].1 as u64);
101 count = 1;
102 unique_pos += 1;
103 }
104 if return_inverse {
105 inverse_vec[pairs[i].1] = unique_pos;
106 }
107 }
108 if return_counts {
109 unique_counts.push(count);
110 }
111 }
112
113 let n = unique_vals.len();
114 let values = Array::from_vec(Ix1::new([n]), unique_vals)?;
115 let indices = if return_index {
116 Some(Array::from_vec(Ix1::new([n]), unique_indices)?)
117 } else {
118 None
119 };
120 let inverse = if return_inverse {
121 Some(Array::from_vec(Ix1::new([n_data]), inverse_vec)?)
122 } else {
123 None
124 };
125 let counts = if return_counts {
126 Some(Array::from_vec(Ix1::new([n]), unique_counts)?)
127 } else {
128 None
129 };
130
131 Ok(UniqueResult {
132 values,
133 indices,
134 inverse,
135 counts,
136 })
137}
138
139pub fn nonzero<T, D>(a: &Array<T, D>) -> FerrayResult<Vec<Array<u64, Ix1>>>
150where
151 T: Element + PartialEq + Copy,
152 D: Dimension,
153{
154 let shape = a.shape();
155 let ndim = shape.len();
156 let zero = <T as Element>::zero();
157
158 let mut indices_per_dim: Vec<Vec<u64>> = vec![Vec::new(); ndim];
160
161 let mut strides = vec![1usize; ndim];
163 for i in (0..ndim.saturating_sub(1)).rev() {
164 strides[i] = strides[i + 1] * shape[i + 1];
165 }
166
167 for (flat_idx, &val) in a.iter().enumerate() {
168 if val != zero {
169 let mut rem = flat_idx;
170 for d in 0..ndim {
171 indices_per_dim[d].push((rem / strides[d]) as u64);
172 rem %= strides[d];
173 }
174 }
175 }
176
177 let mut result = Vec::with_capacity(ndim);
178 for idx_vec in indices_per_dim {
179 let n = idx_vec.len();
180 result.push(Array::from_vec(Ix1::new([n]), idx_vec)?);
181 }
182
183 Ok(result)
184}
185
186pub fn where_<T, D>(
199 condition: &Array<bool, D>,
200 x: &Array<T, D>,
201 y: &Array<T, D>,
202) -> FerrayResult<Array<T, D>>
203where
204 T: Element + Copy,
205 D: Dimension,
206{
207 if condition.shape() != x.shape() || condition.shape() != y.shape() {
208 return Err(FerrayError::shape_mismatch(format!(
209 "condition, x, y shapes must match: {:?}, {:?}, {:?}",
210 condition.shape(),
211 x.shape(),
212 y.shape()
213 )));
214 }
215
216 let result: Vec<T> = condition
217 .iter()
218 .zip(x.iter())
219 .zip(y.iter())
220 .map(|((&c, &xv), &yv)| if c { xv } else { yv })
221 .collect();
222
223 Array::from_vec(condition.dim().clone(), result)
224}
225
226pub fn where_condition<D: Dimension>(
233 condition: &Array<bool, D>,
234) -> FerrayResult<Vec<Array<u64, Ix1>>> {
235 let shape = condition.shape();
236 let ndim = shape.len();
237 let mut indices_per_dim: Vec<Vec<u64>> = vec![Vec::new(); ndim];
238
239 let mut strides = vec![1usize; ndim];
240 for i in (0..ndim.saturating_sub(1)).rev() {
241 strides[i] = strides[i + 1] * shape[i + 1];
242 }
243
244 for (flat_idx, &val) in condition.iter().enumerate() {
245 if val {
246 let mut rem = flat_idx;
247 for d in 0..ndim {
248 indices_per_dim[d].push((rem / strides[d]) as u64);
249 rem %= strides[d];
250 }
251 }
252 }
253
254 indices_per_dim
255 .into_iter()
256 .map(|v| {
257 let n = v.len();
258 Array::from_vec(Ix1::new([n]), v)
259 })
260 .collect()
261}
262
263pub fn count_nonzero<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
271where
272 T: Element + PartialEq + Copy,
273 D: Dimension,
274{
275 let zero = <T as Element>::zero();
276 let data = borrow_data(a);
277 match axis {
278 None => {
279 let count = data.iter().filter(|&&x| x != zero).count() as u64;
280 make_result(&[], vec![count])
281 }
282 Some(ax) => {
283 validate_axis(ax, a.ndim())?;
284 let shape = a.shape();
285 let out_s = output_shape(shape, ax);
286 let result = reduce_axis_general_u64(&data, shape, ax, |lane| {
287 lane.iter().filter(|&&x| x != zero).count() as u64
288 });
289 make_result(&out_s, result)
290 }
291 }
292}
293
294pub fn unique_values<T, D>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>>
303where
304 T: Element + PartialOrd + Copy,
305 D: Dimension,
306{
307 Ok(unique(a, false, false, false)?.values)
308}
309
310pub fn unique_counts<T, D>(a: &Array<T, D>) -> FerrayResult<(Array<T, Ix1>, Array<u64, Ix1>)>
314where
315 T: Element + PartialOrd + Copy,
316 D: Dimension,
317{
318 let r = unique(a, false, false, true)?;
319 Ok((r.values, r.counts.expect("return_counts requested")))
320}
321
322pub fn unique_inverse<T, D>(a: &Array<T, D>) -> FerrayResult<(Array<T, Ix1>, Array<u64, Ix1>)>
327where
328 T: Element + PartialOrd + Copy,
329 D: Dimension,
330{
331 let r = unique(a, false, true, false)?;
332 Ok((r.values, r.inverse.expect("return_inverse requested")))
333}
334
335#[allow(clippy::type_complexity)]
342pub fn unique_all<T, D>(
343 a: &Array<T, D>,
344) -> FerrayResult<(
345 Array<T, Ix1>,
346 Array<u64, Ix1>,
347 Array<u64, Ix1>,
348 Array<u64, Ix1>,
349)>
350where
351 T: Element + PartialOrd + Copy,
352 D: Dimension,
353{
354 let r = unique(a, true, true, true)?;
355 Ok((
356 r.values,
357 r.indices.expect("return_index requested"),
358 r.inverse.expect("return_inverse requested"),
359 r.counts.expect("return_counts requested"),
360 ))
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::{Ix1, Ix2};

    /// Builds a 1-D array whose length is taken from `data`.
    fn arr1<T: Element + Copy>(data: Vec<T>) -> Array<T, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Builds a 2x3 array from six values, in the order `from_vec` expects.
    fn arr2x3<T: Element + Copy>(data: Vec<T>) -> Array<T, Ix2> {
        Array::from_vec(Ix2::new([2, 3]), data).unwrap()
    }

    /// Copies an array's elements, in iteration order, into a `Vec`.
    fn to_vec<T: Element + Copy, D: Dimension>(a: &Array<T, D>) -> Vec<T> {
        a.iter().copied().collect()
    }

    /// Rebuilds the flattened input from unique values and an inverse mapping.
    fn reconstruct(vals: &[i32], inv: &[u64]) -> Vec<i32> {
        inv.iter().map(|&i| vals[i as usize]).collect()
    }

    // ---- unique ----------------------------------------------------------

    #[test]
    fn test_unique_basic() {
        let a = arr1::<i32>(vec![3, 1, 2, 1, 3, 2]);
        let u = unique(&a, false, false, false).unwrap();
        assert_eq!(to_vec(&u.values), vec![1, 2, 3]);
    }

    #[test]
    fn test_unique_with_counts() {
        let a = arr1::<i32>(vec![3, 1, 2, 1, 3, 2]);
        let u = unique(&a, false, false, true).unwrap();
        assert_eq!(to_vec(&u.values), vec![1, 2, 3]);
        assert_eq!(to_vec(&u.counts.unwrap()), vec![2, 2, 2]);
    }

    #[test]
    fn test_unique_with_index() {
        let a = arr1::<i32>(vec![5, 3, 3, 1, 5]);
        let u = unique(&a, true, false, false).unwrap();
        assert_eq!(to_vec(&u.values), vec![1, 3, 5]);
        // Each index is the first occurrence of the corresponding value.
        assert_eq!(to_vec(&u.indices.unwrap()), vec![3, 1, 0]);
    }

    #[test]
    fn test_unique_inverse_reconstructs_input() {
        let input = vec![3, 1, 2, 1, 3, 2];
        let a = arr1::<i32>(input.clone());
        let u = unique(&a, false, true, false).unwrap();
        let vals = to_vec(&u.values);
        let inv = to_vec(&u.inverse.unwrap());
        assert_eq!(vals, vec![1, 2, 3]);
        assert_eq!(reconstruct(&vals, &inv), input);
    }

    #[test]
    fn test_unique_inverse_all_together() {
        let a = arr1::<i32>(vec![2, 1, 2, 3, 1, 2, 3]);
        let u = unique(&a, true, true, true).unwrap();
        let vals = to_vec(&u.values);
        let idxs = to_vec(&u.indices.unwrap());
        let inv = to_vec(&u.inverse.unwrap());
        let cnts = to_vec(&u.counts.unwrap());
        assert_eq!(vals, vec![1, 2, 3]);
        assert_eq!(idxs, vec![1, 0, 3]);
        assert_eq!(cnts, vec![2, 3, 2]);
        assert_eq!(reconstruct(&vals, &inv), vec![2, 1, 2, 3, 1, 2, 3]);
    }

    #[test]
    fn test_unique_inverse_with_2d_flattens_first() {
        let flat: Vec<i32> = vec![1, 2, 1, 3, 2, 1];
        let a = arr2x3::<i32>(flat.clone());
        let u = unique(&a, false, true, false).unwrap();
        let vals = to_vec(&u.values);
        let inv = to_vec(&u.inverse.unwrap());
        assert_eq!(vals, vec![1, 2, 3]);
        assert_eq!(inv.len(), 6);
        assert_eq!(reconstruct(&vals, &inv), flat);
    }

    #[test]
    fn test_unique_inverse_empty_input() {
        let a = arr1::<i32>(vec![]);
        let u = unique(&a, false, true, false).unwrap();
        assert_eq!(u.values.shape(), &[0]);
        let inv = u.inverse.unwrap();
        assert_eq!(inv.shape(), &[0]);
    }

    #[test]
    fn test_unique_inverse_single_value() {
        let a = arr1::<i32>(vec![7, 7, 7, 7]);
        let u = unique(&a, false, true, false).unwrap();
        assert_eq!(to_vec(&u.inverse.unwrap()), vec![0, 0, 0, 0]);
    }

    #[test]
    fn test_unique_without_inverse_leaves_field_none() {
        let a = arr1::<i32>(vec![1, 2, 1]);
        let u = unique(&a, false, false, false).unwrap();
        assert!(u.inverse.is_none());
    }

    // ---- nonzero ---------------------------------------------------------

    #[test]
    fn test_nonzero_1d() {
        let a = arr1::<i32>(vec![0, 1, 0, 3, 0]);
        let nz = nonzero(&a).unwrap();
        assert_eq!(nz.len(), 1);
        assert_eq!(to_vec(&nz[0]), vec![1, 3]);
    }

    #[test]
    fn test_nonzero_2d() {
        let a = arr2x3::<i32>(vec![0, 1, 0, 3, 0, 5]);
        let nz = nonzero(&a).unwrap();
        assert_eq!(nz.len(), 2);
        assert_eq!(to_vec(&nz[0]), vec![0, 1, 1]);
        assert_eq!(to_vec(&nz[1]), vec![1, 0, 2]);
    }

    // ---- where_ / where_condition ----------------------------------------

    #[test]
    fn test_where_basic() {
        let cond = arr1(vec![true, false, true, false]);
        let x = arr1::<f64>(vec![1.0, 2.0, 3.0, 4.0]);
        let y = arr1::<f64>(vec![10.0, 20.0, 30.0, 40.0]);
        let r = where_(&cond, &x, &y).unwrap();
        assert_eq!(to_vec(&r), vec![1.0, 20.0, 3.0, 40.0]);
    }

    #[test]
    fn test_where_shape_mismatch() {
        let cond = arr1(vec![true, false, true]);
        let x = arr1::<f64>(vec![1.0, 2.0, 3.0, 4.0]);
        let y = arr1::<f64>(vec![10.0, 20.0, 30.0, 40.0]);
        assert!(where_(&cond, &x, &y).is_err());
    }

    #[test]
    fn test_where_condition_1d() {
        let cond = arr1(vec![false, true, false, true]);
        let idx = where_condition(&cond).unwrap();
        assert_eq!(idx.len(), 1);
        assert_eq!(to_vec(&idx[0]), vec![1, 3]);
    }

    #[test]
    fn test_where_condition_2d() {
        let cond = arr2x3(vec![false, true, false, true, false, true]);
        let idx = where_condition(&cond).unwrap();
        assert_eq!(idx.len(), 2);
        assert_eq!(to_vec(&idx[0]), vec![0, 1, 1]);
        assert_eq!(to_vec(&idx[1]), vec![1, 0, 2]);
    }

    #[test]
    fn test_where_condition_all_false() {
        let cond = arr1(vec![false, false, false]);
        let idx = where_condition(&cond).unwrap();
        // One (empty) index array per dimension.
        assert_eq!(idx.len(), 1);
        assert!(to_vec(&idx[0]).is_empty());
    }

    // ---- count_nonzero ---------------------------------------------------

    #[test]
    fn test_count_nonzero_total() {
        let a = arr1::<i32>(vec![0, 1, 0, 3, 0]);
        let c = count_nonzero(&a, None).unwrap();
        assert_eq!(c.iter().next(), Some(&2u64));
    }

    #[test]
    fn test_count_nonzero_axis() {
        let a = arr2x3::<i32>(vec![0, 1, 0, 3, 0, 5]);
        let c = count_nonzero(&a, Some(0)).unwrap();
        assert_eq!(to_vec(&c), vec![1, 1, 1]);
    }

    #[test]
    fn test_count_nonzero_axis1() {
        let a = arr2x3::<i32>(vec![0, 1, 0, 3, 0, 5]);
        let c = count_nonzero(&a, Some(1)).unwrap();
        assert_eq!(to_vec(&c), vec![1, 2]);
    }

    // ---- unique_* convenience wrappers -----------------------------------

    #[test]
    fn test_unique_values_alias() {
        let a = arr1::<i32>(vec![3, 1, 2, 1, 3, 2]);
        let v = unique_values(&a).unwrap();
        assert_eq!(to_vec(&v), vec![1, 2, 3]);
    }

    #[test]
    fn test_unique_counts_alias() {
        let a = arr1::<i32>(vec![3, 1, 2, 1, 3, 2]);
        let (v, c) = unique_counts(&a).unwrap();
        assert_eq!(to_vec(&v), vec![1, 2, 3]);
        assert_eq!(to_vec(&c), vec![2, 2, 2]);
    }

    #[test]
    fn test_unique_inverse_alias() {
        let a = arr1::<i32>(vec![3, 1, 2, 1, 3, 2]);
        let (v, inv) = unique_inverse(&a).unwrap();
        assert_eq!(to_vec(&v), vec![1, 2, 3]);
        assert_eq!(to_vec(&inv), vec![2, 0, 1, 0, 2, 1]);
    }

    #[test]
    fn test_unique_all_alias() {
        let a = arr1::<i32>(vec![3, 1, 2, 1, 3, 2]);
        let (v, _idx, inv, c) = unique_all(&a).unwrap();
        assert_eq!(to_vec(&v), vec![1, 2, 3]);
        assert_eq!(to_vec(&inv), vec![2, 0, 1, 0, 2, 1]);
        assert_eq!(to_vec(&c), vec![2, 2, 2]);
    }
}