1use ferray_core::error::{FerrayError, FerrayResult};
4use ferray_core::{Array, Dimension, Element, Ix1, IxDyn};
5
6use crate::reductions::{
7 borrow_data, make_result, output_shape, reduce_axis_general_u64, validate_axis,
8};
9
10#[derive(Debug)]
16pub struct UniqueResult<T: Element> {
17 pub values: Array<T, Ix1>,
19 pub indices: Option<Array<u64, Ix1>>,
22 pub counts: Option<Array<u64, Ix1>>,
24}
25
26pub fn unique<T, D>(
32 a: &Array<T, D>,
33 return_index: bool,
34 return_counts: bool,
35) -> FerrayResult<UniqueResult<T>>
36where
37 T: Element + PartialOrd + Copy,
38 D: Dimension,
39{
40 let data: Vec<T> = a.iter().copied().collect();
41
42 let mut pairs: Vec<(T, usize)> = data
44 .iter()
45 .copied()
46 .enumerate()
47 .map(|(i, v)| (v, i))
48 .collect();
49 pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
50
51 let mut unique_vals = Vec::new();
53 let mut unique_indices: Vec<u64> = Vec::new();
54 let mut unique_counts: Vec<u64> = Vec::new();
55
56 if !pairs.is_empty() {
57 unique_vals.push(pairs[0].0);
58 unique_indices.push(pairs[0].1 as u64);
59 let mut count = 1u64;
60
61 for i in 1..pairs.len() {
62 if pairs[i].0.partial_cmp(&pairs[i - 1].0) != Some(std::cmp::Ordering::Equal) {
63 if return_counts {
64 unique_counts.push(count);
65 }
66 unique_vals.push(pairs[i].0);
67 unique_indices.push(pairs[i].1 as u64);
68 count = 1;
69 } else {
70 count += 1;
71 let last = unique_indices.len() - 1;
73 let new_idx = pairs[i].1 as u64;
74 if new_idx < unique_indices[last] {
75 unique_indices[last] = new_idx;
76 }
77 }
78 }
79 if return_counts {
80 unique_counts.push(count);
81 }
82 }
83
84 let n = unique_vals.len();
85 let values = Array::from_vec(Ix1::new([n]), unique_vals)?;
86 let indices = if return_index {
87 Some(Array::from_vec(Ix1::new([n]), unique_indices)?)
88 } else {
89 None
90 };
91 let counts = if return_counts {
92 Some(Array::from_vec(Ix1::new([n]), unique_counts)?)
93 } else {
94 None
95 };
96
97 Ok(UniqueResult {
98 values,
99 indices,
100 counts,
101 })
102}
103
104pub fn nonzero<T, D>(a: &Array<T, D>) -> FerrayResult<Vec<Array<u64, Ix1>>>
115where
116 T: Element + PartialEq + Copy,
117 D: Dimension,
118{
119 let shape = a.shape();
120 let ndim = shape.len();
121 let zero = <T as Element>::zero();
122
123 let mut indices_per_dim: Vec<Vec<u64>> = vec![Vec::new(); ndim];
125
126 let mut strides = vec![1usize; ndim];
128 for i in (0..ndim.saturating_sub(1)).rev() {
129 strides[i] = strides[i + 1] * shape[i + 1];
130 }
131
132 for (flat_idx, &val) in a.iter().enumerate() {
133 if val != zero {
134 let mut rem = flat_idx;
135 for d in 0..ndim {
136 indices_per_dim[d].push((rem / strides[d]) as u64);
137 rem %= strides[d];
138 }
139 }
140 }
141
142 let mut result = Vec::with_capacity(ndim);
143 for idx_vec in indices_per_dim {
144 let n = idx_vec.len();
145 result.push(Array::from_vec(Ix1::new([n]), idx_vec)?);
146 }
147
148 Ok(result)
149}
150
151pub fn where_<T, D>(
164 condition: &Array<bool, D>,
165 x: &Array<T, D>,
166 y: &Array<T, D>,
167) -> FerrayResult<Array<T, D>>
168where
169 T: Element + Copy,
170 D: Dimension,
171{
172 if condition.shape() != x.shape() || condition.shape() != y.shape() {
173 return Err(FerrayError::shape_mismatch(format!(
174 "condition, x, y shapes must match: {:?}, {:?}, {:?}",
175 condition.shape(),
176 x.shape(),
177 y.shape()
178 )));
179 }
180
181 let result: Vec<T> = condition
182 .iter()
183 .zip(x.iter())
184 .zip(y.iter())
185 .map(|((&c, &xv), &yv)| if c { xv } else { yv })
186 .collect();
187
188 Array::from_vec(condition.dim().clone(), result)
189}
190
191pub fn count_nonzero<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<u64, IxDyn>>
199where
200 T: Element + PartialEq + Copy,
201 D: Dimension,
202{
203 let zero = <T as Element>::zero();
204 let data = borrow_data(a);
205 match axis {
206 None => {
207 let count = data.iter().filter(|&&x| x != zero).count() as u64;
208 make_result(&[], vec![count])
209 }
210 Some(ax) => {
211 validate_axis(ax, a.ndim())?;
212 let shape = a.shape();
213 let out_s = output_shape(shape, ax);
214 let result = reduce_axis_general_u64(&data, shape, ax, |lane| {
215 lane.iter().filter(|&&x| x != zero).count() as u64
216 });
217 make_result(&out_s, result)
218 }
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::{Ix1, Ix2};

    /// Duplicates collapse and the survivors come back in ascending order.
    #[test]
    fn test_unique_basic() {
        let input = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![3, 1, 2, 1, 3, 2]).unwrap();
        let found = unique(&input, false, false).unwrap();
        assert_eq!(
            found.values.iter().copied().collect::<Vec<i32>>(),
            vec![1, 2, 3]
        );
    }

    /// Requesting counts reports how often each distinct value occurs.
    #[test]
    fn test_unique_with_counts() {
        let input = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![3, 1, 2, 1, 3, 2]).unwrap();
        let found = unique(&input, false, true).unwrap();
        let values = found.values.iter().copied().collect::<Vec<i32>>();
        let counts = found.counts.unwrap().iter().copied().collect::<Vec<u64>>();
        assert_eq!(values, vec![1, 2, 3]);
        assert_eq!(counts, vec![2, 2, 2]);
    }

    /// Requesting indices reports the first occurrence of each distinct value.
    #[test]
    fn test_unique_with_index() {
        let input = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![5, 3, 3, 1, 5]).unwrap();
        let found = unique(&input, true, false).unwrap();
        let values = found.values.iter().copied().collect::<Vec<i32>>();
        let first_seen = found.indices.unwrap().iter().copied().collect::<Vec<u64>>();
        assert_eq!(values, vec![1, 3, 5]);
        assert_eq!(first_seen, vec![3, 1, 0]);
    }

    /// A 1-D input yields a single index array.
    #[test]
    fn test_nonzero_1d() {
        let input = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let located = nonzero(&input).unwrap();
        assert_eq!(located.len(), 1);
        assert_eq!(located[0].iter().copied().collect::<Vec<u64>>(), vec![1, 3]);
    }

    /// A 2-D input yields matching row and column index arrays.
    #[test]
    fn test_nonzero_2d() {
        let input =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 3, 0, 5]).unwrap();
        let located = nonzero(&input).unwrap();
        assert_eq!(located.len(), 2);
        let row_idx = located[0].iter().copied().collect::<Vec<u64>>();
        let col_idx = located[1].iter().copied().collect::<Vec<u64>>();
        assert_eq!(row_idx, vec![0, 1, 1]);
        assert_eq!(col_idx, vec![1, 0, 2]);
    }

    /// `true` picks from `x`, `false` picks from `y`.
    #[test]
    fn test_where_basic() {
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        let when_true = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let when_false =
            Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let chosen = where_(&mask, &when_true, &when_false).unwrap();
        assert_eq!(
            chosen.iter().copied().collect::<Vec<f64>>(),
            vec![1.0, 20.0, 3.0, 40.0]
        );
    }

    /// A condition whose shape differs from the operands is rejected.
    #[test]
    fn test_where_shape_mismatch() {
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        let when_true = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let when_false =
            Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let outcome = where_(&mask, &when_true, &when_false);
        assert!(outcome.is_err());
    }

    /// With no axis, a single total is produced.
    #[test]
    fn test_count_nonzero_total() {
        let input = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 0, 3, 0]).unwrap();
        let total = count_nonzero(&input, None).unwrap();
        assert_eq!(total.iter().copied().next(), Some(2u64));
    }

    /// Counting along axis 0 gives one count per column.
    #[test]
    fn test_count_nonzero_axis() {
        let input =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 3, 0, 5]).unwrap();
        let per_column = count_nonzero(&input, Some(0)).unwrap();
        assert_eq!(
            per_column.iter().copied().collect::<Vec<u64>>(),
            vec![1, 1, 1]
        );
    }

    /// Counting along axis 1 gives one count per row.
    #[test]
    fn test_count_nonzero_axis1() {
        let input =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0, 1, 0, 3, 0, 5]).unwrap();
        let per_row = count_nonzero(&input, Some(1)).unwrap();
        assert_eq!(per_row.iter().copied().collect::<Vec<u64>>(), vec![1, 2]);
    }
}