1use ferray_core::dimension::{Dimension, IxDyn};
36use ferray_core::error::{FerrayError, FerrayResult};
37use ferray_core::{Array, Ix1};
38
39use crate::MaskedArray;
40
/// Returns the C-order (row-major) element strides for `shape`.
///
/// The last axis always has stride `1`. Every earlier axis has a stride equal to the product of
/// the extents of all axes after it. An empty `shape` yields an empty vector.
fn row_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = Vec::with_capacity(shape.len());
    if let Some((_, trailing)) = shape.split_first() {
        // Walk from the innermost axis outwards, accumulating the extents seen so far.
        // The outermost extent is never multiplied in, because no stride depends on it.
        let mut step = 1usize;
        strides.push(step);
        for &extent in trailing.iter().rev() {
            step *= extent;
            strides.push(step);
        }
        strides.reverse();
    }
    strides
}
50
51pub fn ma_where<D>(
65 condition: &MaskedArray<bool, D>,
66 x: &MaskedArray<f64, D>,
67 y: &MaskedArray<f64, D>,
68) -> FerrayResult<MaskedArray<f64, D>>
69where
70 D: Dimension,
71{
72 let shape = condition.shape();
73 if x.shape() != shape || y.shape() != shape {
74 return Err(FerrayError::shape_mismatch(format!(
75 "ma_where: condition {:?}, x {:?}, y {:?} must share one shape",
76 shape,
77 x.shape(),
78 y.shape()
79 )));
80 }
81 let cd: Vec<bool> = condition.data().iter().copied().collect();
82 let cm: Vec<bool> = condition.mask().iter().copied().collect();
83 let xd: Vec<f64> = x.data().iter().copied().collect();
84 let xm: Vec<bool> = x.mask().iter().copied().collect();
85 let yd: Vec<f64> = y.data().iter().copied().collect();
86 let ym: Vec<bool> = y.mask().iter().copied().collect();
87
88 let n = cd.len();
89 let mut out_data = Vec::with_capacity(n);
90 let mut out_mask = Vec::with_capacity(n);
91 for i in 0..n {
92 let cf = cd[i] && !cm[i];
94 if cf {
95 out_data.push(xd[i]);
96 out_mask.push(xm[i] || cm[i]);
97 } else {
98 out_data.push(yd[i]);
99 out_mask.push(ym[i] || cm[i]);
100 }
101 }
102 let data_arr = Array::from_vec(condition.data().dim().clone(), out_data)?;
103 let mask_arr = Array::from_vec(condition.data().dim().clone(), out_mask)?;
104 MaskedArray::new(data_arr, mask_arr)
105}
106
107pub fn ma_choose<D>(
120 indices: &MaskedArray<f64, D>,
121 choices: &[MaskedArray<f64, D>],
122) -> FerrayResult<MaskedArray<f64, D>>
123where
124 D: Dimension,
125{
126 if choices.is_empty() {
127 return Err(FerrayError::invalid_value(
128 "ma_choose: choices must be a non-empty sequence",
129 ));
130 }
131 let shape = indices.shape();
132 for (k, c) in choices.iter().enumerate() {
133 if c.shape() != shape {
134 return Err(FerrayError::shape_mismatch(format!(
135 "ma_choose: choices[{k}] shape {:?} != indices shape {:?}",
136 c.shape(),
137 shape
138 )));
139 }
140 }
141 let idx_d: Vec<f64> = indices.data().iter().copied().collect();
142 let idx_m: Vec<bool> = indices.mask().iter().copied().collect();
143 let ch_d: Vec<Vec<f64>> = choices
144 .iter()
145 .map(|c| c.data().iter().copied().collect())
146 .collect();
147 let ch_m: Vec<Vec<bool>> = choices
148 .iter()
149 .map(|c| c.mask().iter().copied().collect())
150 .collect();
151
152 let n = idx_d.len();
153 let nchoices = choices.len();
154 let mut out_data = Vec::with_capacity(n);
155 let mut out_mask = Vec::with_capacity(n);
156 for i in 0..n {
157 let raw = if idx_m[i] { 0.0 } else { idx_d[i] };
159 if !(raw.is_finite() && raw.fract() == 0.0 && raw >= 0.0) {
160 return Err(FerrayError::invalid_value(format!(
161 "ma_choose: index {raw} is not a non-negative integer"
162 )));
163 }
164 let k = raw as usize;
165 if k >= nchoices {
166 return Err(FerrayError::index_out_of_bounds(k as isize, 0, nchoices));
167 }
168 out_data.push(ch_d[k][i]);
169 out_mask.push(ch_m[k][i] || idx_m[i]);
170 }
171 let data_arr = Array::from_vec(indices.data().dim().clone(), out_data)?;
172 let mask_arr = Array::from_vec(indices.data().dim().clone(), out_mask)?;
173 MaskedArray::new(data_arr, mask_arr)
174}
175
176fn diff_once(
179 data: &[f64],
180 mask: &[bool],
181 shape: &[usize],
182 axis: usize,
183) -> FerrayResult<(Vec<f64>, Vec<bool>, Vec<usize>)> {
184 let ndim = shape.len();
185 let strides = row_major_strides(shape);
186 let lane = shape[axis];
187
188 let mut out_shape = shape.to_vec();
189 out_shape[axis] = lane.saturating_sub(1);
191 let out_size: usize = out_shape.iter().product();
192
193 let out_strides = row_major_strides(&out_shape);
194 let mut out_data = vec![0.0f64; out_size];
195 let mut out_mask = vec![false; out_size];
196
197 let mut multi = vec![0usize; ndim];
199 for _ in 0..out_size {
200 let out_flat: usize = multi
201 .iter()
202 .zip(out_strides.iter())
203 .map(|(i, s)| i * s)
204 .sum();
205 let lo_flat: usize = multi.iter().zip(strides.iter()).map(|(i, s)| i * s).sum();
207 let hi_flat = lo_flat + strides[axis];
208 out_data[out_flat] = data[hi_flat] - data[lo_flat];
209 out_mask[out_flat] = mask[lo_flat] || mask[hi_flat];
210
211 for d in (0..ndim).rev() {
213 multi[d] += 1;
214 if multi[d] < out_shape[d] {
215 break;
216 }
217 multi[d] = 0;
218 }
219 }
220 Ok((out_data, out_mask, out_shape))
221}
222
223pub fn ma_diff(
235 a: &MaskedArray<f64, IxDyn>,
236 n: usize,
237 axis: isize,
238) -> FerrayResult<MaskedArray<f64, IxDyn>> {
239 if a.ndim() == 0 {
240 return Err(FerrayError::invalid_value(
241 "ma_diff: input must be at least one dimensional",
242 ));
243 }
244 let ndim = a.ndim();
245 let axis_u = if axis < 0 {
246 let adj = axis + ndim as isize;
247 if adj < 0 {
248 return Err(FerrayError::axis_out_of_bounds(axis.unsigned_abs(), ndim));
249 }
250 adj as usize
251 } else {
252 axis as usize
253 };
254 if axis_u >= ndim {
255 return Err(FerrayError::axis_out_of_bounds(axis_u, ndim));
256 }
257 if n == 0 {
258 return Ok(a.clone());
259 }
260
261 let mut data: Vec<f64> = a.data().iter().copied().collect();
262 let mut mask: Vec<bool> = a.mask().iter().copied().collect();
263 let mut shape: Vec<usize> = a.shape().to_vec();
264
265 for _ in 0..n {
266 if shape[axis_u] == 0 {
267 break;
268 }
269 let (d, m, s) = diff_once(&data, &mask, &shape, axis_u)?;
270 data = d;
271 mask = m;
272 shape = s;
273 }
274
275 let data_arr = Array::<f64, IxDyn>::from_vec(IxDyn::new(&shape), data)?;
276 let mask_arr = Array::<bool, IxDyn>::from_vec(IxDyn::new(&shape), mask)?;
277 MaskedArray::new(data_arr, mask_arr)
278}
279
280pub fn ma_ediff1d<D>(
292 ary: &MaskedArray<f64, D>,
293 to_begin: Option<&[f64]>,
294 to_end: Option<&[f64]>,
295) -> FerrayResult<MaskedArray<f64, Ix1>>
296where
297 D: Dimension,
298{
299 let flat_d: Vec<f64> = ary.data().iter().copied().collect();
300 let flat_m: Vec<bool> = ary.mask().iter().copied().collect();
301
302 let mut data: Vec<f64> = Vec::new();
303 let mut mask: Vec<bool> = Vec::new();
304
305 if let Some(begin) = to_begin {
306 for &v in begin {
307 data.push(v);
308 mask.push(false);
309 }
310 }
311 if flat_d.len() >= 2 {
312 for i in 0..flat_d.len() - 1 {
313 data.push(flat_d[i + 1] - flat_d[i]);
314 mask.push(flat_m[i] || flat_m[i + 1]);
315 }
316 }
317 if let Some(end) = to_end {
318 for &v in end {
319 data.push(v);
320 mask.push(false);
321 }
322 }
323
324 let len = data.len();
325 let data_arr = Array::<f64, Ix1>::from_vec(Ix1::new([len]), data)?;
326 let mask_arr = Array::<bool, Ix1>::from_vec(Ix1::new([len]), mask)?;
327 MaskedArray::new(data_arr, mask_arr)
328}
329
330pub fn ma_nonzero<D>(a: &MaskedArray<f64, D>) -> FerrayResult<Vec<Array<i64, Ix1>>>
341where
342 D: Dimension,
343{
344 let shape = a.shape().to_vec();
345 let ndim = shape.len().max(1);
346 let data: Vec<f64> = a.data().iter().copied().collect();
347 let mask: Vec<bool> = a.mask().iter().copied().collect();
348 let strides = row_major_strides(&shape);
349
350 let mut coords: Vec<Vec<i64>> = vec![Vec::new(); ndim];
351 for (flat, (&v, &m)) in data.iter().zip(mask.iter()).enumerate() {
352 if m || v == 0.0 {
354 continue;
355 }
356 if shape.is_empty() {
357 coords[0].push(0);
359 continue;
360 }
361 let mut rem = flat;
362 for d in 0..shape.len() {
363 let c = rem / strides[d];
364 rem %= strides[d];
365 coords[d].push(c as i64);
366 }
367 }
368
369 let mut out = Vec::with_capacity(ndim);
370 for axis_coords in coords {
371 let len = axis_coords.len();
372 out.push(Array::<i64, Ix1>::from_vec(Ix1::new([len]), axis_coords)?);
373 }
374 Ok(out)
375}
376
#[cfg(test)]
mod tests {
    //! Unit tests for the masked selection / difference helpers.
    //! Besides the NumPy-parity cases, these cover the error paths and edge cases
    //! (shape mismatch, empty choices, invalid indices, bad axis, `n == 0`, over-long `n`).
    use super::*;
    use ferray_core::Array;

    fn ma1(data: &[f64], mask: &[bool]) -> MaskedArray<f64, Ix1> {
        let n = data.len();
        let d = Array::<f64, Ix1>::from_vec(Ix1::new([n]), data.to_vec()).unwrap();
        let m = Array::<bool, Ix1>::from_vec(Ix1::new([n]), mask.to_vec()).unwrap();
        MaskedArray::new(d, m).unwrap()
    }

    fn mb1(data: &[bool], mask: &[bool]) -> MaskedArray<bool, Ix1> {
        let n = data.len();
        let d = Array::<bool, Ix1>::from_vec(Ix1::new([n]), data.to_vec()).unwrap();
        let m = Array::<bool, Ix1>::from_vec(Ix1::new([n]), mask.to_vec()).unwrap();
        MaskedArray::new(d, m).unwrap()
    }

    fn dyn_ma(data: &[f64], mask: &[bool], shape: &[usize]) -> MaskedArray<f64, IxDyn> {
        let d = Array::<f64, IxDyn>::from_vec(IxDyn::new(shape), data.to_vec()).unwrap();
        let m = Array::<bool, IxDyn>::from_vec(IxDyn::new(shape), mask.to_vec()).unwrap();
        MaskedArray::new(d, m).unwrap()
    }

    #[test]
    fn where_matches_numpy_scalar_branches() {
        let cond = mb1(&[true, false, true], &[false, true, false]);
        let x = ma1(&[10.0, 10.0, 10.0], &[false, false, false]);
        let y = ma1(&[20.0, 20.0, 20.0], &[false, false, false]);
        let out = ma_where(&cond, &x, &y).unwrap();
        assert_eq!(
            out.data().iter().copied().collect::<Vec<_>>(),
            vec![10.0, 20.0, 10.0]
        );
        assert_eq!(
            out.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, false]
        );
    }

    #[test]
    fn where_propagates_source_mask() {
        let cond = mb1(&[true, true, false], &[false, false, false]);
        let x = ma1(&[1.0, 2.0, 3.0], &[false, true, false]);
        let y = ma1(&[4.0, 5.0, 6.0], &[true, false, true]);
        let out = ma_where(&cond, &x, &y).unwrap();
        assert_eq!(
            out.data().iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 6.0]
        );
        assert_eq!(
            out.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, true]
        );
    }

    #[test]
    fn where_rejects_shape_mismatch() {
        let cond = mb1(&[true, false], &[false, false]);
        let x = ma1(&[1.0, 2.0, 3.0], &[false, false, false]);
        let y = ma1(&[4.0, 5.0, 6.0], &[false, false, false]);
        assert!(ma_where(&cond, &x, &y).is_err());
    }

    #[test]
    fn choose_matches_numpy() {
        let idx = ma1(&[0.0, 1.0, 0.0], &[false, true, false]);
        let c0 = ma1(&[10.0, 20.0, 30.0], &[false, false, false]);
        let c1 = ma1(&[40.0, 50.0, 60.0], &[false, false, false]);
        let out = ma_choose(&idx, &[c0, c1]).unwrap();
        assert_eq!(
            out.data().iter().copied().collect::<Vec<_>>(),
            vec![10.0, 20.0, 30.0]
        );
        assert_eq!(
            out.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, false]
        );
    }

    #[test]
    fn choose_propagates_choice_mask() {
        let idx = ma1(&[0.0, 1.0, 0.0], &[false, false, false]);
        let c0 = ma1(&[10.0, 20.0, 30.0], &[false, false, true]);
        let c1 = ma1(&[40.0, 50.0, 60.0], &[false, false, false]);
        let out = ma_choose(&idx, &[c0, c1]).unwrap();
        assert_eq!(
            out.data().iter().copied().collect::<Vec<_>>(),
            vec![10.0, 50.0, 30.0]
        );
        assert_eq!(
            out.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, false, true]
        );
    }

    #[test]
    fn choose_ignores_value_under_masked_index() {
        // A masked index is treated as 0 even if its stored value would be invalid.
        let idx = ma1(&[7.5], &[true]);
        let choices = vec![ma1(&[10.0], &[false]), ma1(&[20.0], &[false])];
        let out = ma_choose(&idx, &choices).unwrap();
        assert_eq!(out.data().iter().copied().collect::<Vec<_>>(), vec![10.0]);
        assert_eq!(out.mask().iter().copied().collect::<Vec<_>>(), vec![true]);
    }

    #[test]
    fn choose_rejects_empty_choices() {
        let idx = ma1(&[0.0], &[false]);
        assert!(ma_choose(&idx, &[]).is_err());
    }

    #[test]
    fn choose_rejects_bad_indices() {
        let choices = vec![ma1(&[1.0], &[false]), ma1(&[2.0], &[false])];
        for &bad in &[2.0, -1.0, 0.5, f64::NAN, f64::INFINITY, 1e30] {
            let idx = ma1(&[bad], &[false]);
            assert!(
                ma_choose(&idx, &choices).is_err(),
                "index {bad} was accepted"
            );
        }
    }

    #[test]
    fn diff_n1_matches_numpy() {
        let data = [1.0, 2.0, 3.0, 4.0, 7.0, 0.0, 2.0, 3.0];
        let mask = [true, false, false, false, false, true, false, false];
        let m = dyn_ma(&data, &mask, &[8]);
        let out = ma_diff(&m, 1, -1).unwrap();
        assert_eq!(
            out.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, false, false, false, true, true, false]
        );
        let d: Vec<f64> = out.data().iter().copied().collect();
        assert_eq!(d[1], 1.0);
        assert_eq!(d[6], 1.0);
    }

    #[test]
    fn diff_n2_matches_numpy() {
        let data = [1.0, 2.0, 3.0, 4.0, 7.0, 0.0, 2.0, 3.0];
        let mask = [true, false, false, false, false, true, false, false];
        let m = dyn_ma(&data, &mask, &[8]);
        let out = ma_diff(&m, 2, -1).unwrap();
        assert_eq!(
            out.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, false, false, true, true, true]
        );
    }

    #[test]
    fn diff_axis0_2d_matches_numpy() {
        let data = [1.0, 3.0, 1.0, 5.0, 10.0, 0.0, 1.0, 5.0, 6.0, 8.0];
        let mask = [
            true, false, true, false, false, false, true, false, false, false,
        ];
        let m = dyn_ma(&data, &mask, &[2, 5]);
        let out = ma_diff(&m, 1, 0).unwrap();
        assert_eq!(out.shape(), &[1, 5]);
        assert_eq!(
            out.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, false, false]
        );
        let d: Vec<f64> = out.data().iter().copied().collect();
        assert_eq!(d[3], 1.0);
        assert_eq!(d[4], -2.0);
    }

    #[test]
    fn diff_axis1_2d_matches_numpy() {
        let data = [1.0, 3.0, 1.0, 5.0, 10.0, 0.0, 1.0, 5.0, 6.0, 8.0];
        let mask = [
            true, false, true, false, false, false, true, false, false, false,
        ];
        let m = dyn_ma(&data, &mask, &[2, 5]);
        let out = ma_diff(&m, 1, 1).unwrap();
        assert_eq!(out.shape(), &[2, 4]);
        assert_eq!(
            out.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, false, true, true, false, false]
        );
        let d: Vec<f64> = out.data().iter().copied().collect();
        assert_eq!(d[3], 5.0);
        assert_eq!(d[6], 1.0);
        assert_eq!(d[7], 2.0);
    }

    #[test]
    fn diff_n_zero_returns_input() {
        let m = dyn_ma(&[1.0, 4.0, 9.0], &[false, true, false], &[3]);
        let out = ma_diff(&m, 0, 0).unwrap();
        assert_eq!(
            out.data().iter().copied().collect::<Vec<_>>(),
            vec![1.0, 4.0, 9.0]
        );
        assert_eq!(
            out.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, false]
        );
    }

    #[test]
    fn diff_n_past_length_yields_empty_axis() {
        let m = dyn_ma(&[1.0, 4.0, 9.0], &[false, false, false], &[3]);
        let out = ma_diff(&m, 5, 0).unwrap();
        assert_eq!(out.shape(), &[0]);
    }

    #[test]
    fn diff_rejects_bad_axis() {
        let m = dyn_ma(&[1.0, 2.0], &[false, false], &[2]);
        assert!(ma_diff(&m, 1, 1).is_err());
        assert!(ma_diff(&m, 1, -2).is_err());
    }

    #[test]
    fn ediff1d_matches_numpy() {
        let m = ma1(&[1.0, 2.0, 3.0, 4.0], &[false, true, false, false]);
        let out = ma_ediff1d(&m, None, None).unwrap();
        assert_eq!(
            out.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, true, false]
        );
        let d: Vec<f64> = out.data().iter().copied().collect();
        assert_eq!(d[2], 1.0);
    }

    #[test]
    fn ediff1d_to_begin_end() {
        let m = ma1(&[1.0, 2.0, 3.0, 4.0], &[false, true, false, false]);
        let out = ma_ediff1d(&m, Some(&[99.0]), Some(&[88.0])).unwrap();
        assert_eq!(
            out.data().iter().copied().collect::<Vec<_>>(),
            vec![99.0, 1.0, 1.0, 1.0, 88.0]
        );
        assert_eq!(
            out.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, true, false, false]
        );
    }

    #[test]
    fn ediff1d_short_input_keeps_only_padding() {
        // A single element yields no differences; only the unmasked padding remains.
        let m = ma1(&[5.0], &[true]);
        let out = ma_ediff1d(&m, Some(&[1.0]), Some(&[2.0])).unwrap();
        assert_eq!(
            out.data().iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0]
        );
        assert_eq!(
            out.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, false]
        );
    }

    #[test]
    fn nonzero_1d_treats_masked_as_zero() {
        let m = ma1(&[0.0, 1.0, 0.0, 2.0], &[false, false, true, false]);
        let out = ma_nonzero(&m).unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].iter().copied().collect::<Vec<_>>(), vec![1, 3]);
    }

    #[test]
    fn nonzero_2d_matches_numpy() {
        let m = dyn_ma(&[0.0, 1.0, 2.0, 0.0], &[false, false, true, false], &[2, 2]);
        let out = ma_nonzero(&m).unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].iter().copied().collect::<Vec<_>>(), vec![0]);
        assert_eq!(out[1].iter().copied().collect::<Vec<_>>(), vec![1]);
    }
}
570}