Skip to main content

ferray_ma/
ufunc_support.rs

1// ferray-ma: Ufunc support wrappers (REQ-12)
2//
3// Wrapper functions that accept MaskedArray and call underlying ferray-ufunc
4// operations on the data, then propagate masks. Masked elements are skipped;
5// their output positions retain the masked array's `fill_value`.
6
7use ferray_core::dimension::Dimension;
8use ferray_core::dtype::Element;
9use ferray_core::error::FerrayResult;
10use num_traits::Float;
11
12use ferray_core::Array;
13use ferray_core::error::FerrayError;
14
15use crate::MaskedArray;
16use crate::arithmetic::{masked_binary_op, masked_unary_op};
17
18// ---------------------------------------------------------------------------
19// Domain-aware ufunc wrappers (#503)
20//
21// NumPy's masked ufuncs have domain checking: `ma.log(masked_array)`
22// automatically masks elements that are out of the log domain
23// (negative values). Similarly `ma.sqrt` masks negative inputs,
24// `ma.divide` masks zero denominators, `ma.arcsin`/`ma.arccos` mask
25// |x| > 1 inputs. The masked positions in the result combine the
26// existing mask with a freshly-computed "domain mask" via OR.
27//
// The plain `log`/`sqrt`/`arcsin`/… wrappers further down this module
// call the raw Float method, so an *unmasked* element whose underlying
// data is out of domain silently produces NaN in the result data —
// annoying for pipelines that assume "masked = bad" but "unmasked = valid".
// The `_domain` variants fix that by computing new_mask = old_mask
// || !in_domain(x) and substituting the fill_value at positions
// newly marked as masked.
35// ---------------------------------------------------------------------------
36
37/// Apply a unary function to a masked array, additionally masking any
38/// position where the underlying data is out of the function's domain.
39///
40/// `in_domain(x)` should return `true` when `x` is a valid input to
41/// `f`. For any unmasked position where `in_domain(x)` returns
42/// `false`, the result mask is set to `true` and the result data
43/// carries the masked array's `fill_value`.
44///
45/// This is the generic helper backing [`log_domain`], [`sqrt_domain`],
46/// [`arcsin_domain`], [`arccos_domain`] (#503).
47pub fn masked_unary_domain<T, D, F, Dom>(
48    ma: &MaskedArray<T, D>,
49    f: F,
50    in_domain: Dom,
51) -> FerrayResult<MaskedArray<T, D>>
52where
53    T: Element + Copy,
54    D: Dimension,
55    F: Fn(T) -> T,
56    Dom: Fn(T) -> bool,
57{
58    let fill = ma.fill_value();
59    let n = ma.size();
60    let mut data_out: Vec<T> = Vec::with_capacity(n);
61    let mut mask_out: Vec<bool> = Vec::with_capacity(n);
62    for (&v, &m) in ma.data().iter().zip(ma.mask().iter()) {
63        // Mask if already masked OR the domain predicate rejects v.
64        let should_mask = m || !in_domain(v);
65        if should_mask {
66            data_out.push(fill);
67            mask_out.push(true);
68        } else {
69            data_out.push(f(v));
70            mask_out.push(false);
71        }
72    }
73    let data_arr = Array::from_vec(ma.dim().clone(), data_out)?;
74    let mask_arr = Array::from_vec(ma.dim().clone(), mask_out)?;
75    let mut out = MaskedArray::new(data_arr, mask_arr)?;
76    out.set_fill_value(fill);
77    Ok(out)
78}
79
80/// Apply a binary function to two masked arrays, additionally masking
81/// any position where `in_domain(a, b)` returns `false`.
82///
83/// The result's mask is `old_a_mask OR old_b_mask OR !in_domain(a, b)`.
84/// Used by [`divide_domain`] to auto-mask zero denominators.
85///
86/// Requires same shape — broadcasting domain-aware ops is left as a
87/// follow-up because the per-position domain check interacts awkwardly
88/// with broadcast strides.
89pub fn masked_binary_domain<T, D, F, Dom>(
90    a: &MaskedArray<T, D>,
91    b: &MaskedArray<T, D>,
92    f: F,
93    in_domain: Dom,
94    op_name: &str,
95) -> FerrayResult<MaskedArray<T, D>>
96where
97    T: Element + Copy,
98    D: Dimension,
99    F: Fn(T, T) -> T,
100    Dom: Fn(T, T) -> bool,
101{
102    if a.shape() != b.shape() {
103        return Err(FerrayError::shape_mismatch(format!(
104            "{op_name}: shapes {:?} and {:?} differ (broadcasting not supported for domain-aware ops)",
105            a.shape(),
106            b.shape()
107        )));
108    }
109    let fill = a.fill_value();
110    let n = a.size();
111    let mut data_out: Vec<T> = Vec::with_capacity(n);
112    let mut mask_out: Vec<bool> = Vec::with_capacity(n);
113    for (((&x, &y), &ma_bit), &mb_bit) in a
114        .data()
115        .iter()
116        .zip(b.data().iter())
117        .zip(a.mask().iter())
118        .zip(b.mask().iter())
119    {
120        // Mask the result if either input is already masked OR the
121        // domain predicate rejects the pair.
122        let should_mask = ma_bit || mb_bit || !in_domain(x, y);
123        if should_mask {
124            data_out.push(fill);
125            mask_out.push(true);
126        } else {
127            data_out.push(f(x, y));
128            mask_out.push(false);
129        }
130    }
131    let data_arr = Array::from_vec(a.dim().clone(), data_out)?;
132    let mask_arr = Array::from_vec(a.dim().clone(), mask_out)?;
133    let mut out = MaskedArray::new(data_arr, mask_arr)?;
134    out.set_fill_value(fill);
135    Ok(out)
136}
137
138/// Natural log with auto-masking of non-positive inputs.
139///
140/// Equivalent to `numpy.ma.log`. Any unmasked element `x <= 0` is
141/// added to the result mask and replaced with the fill value in the
142/// result data.
143pub fn log_domain<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
144where
145    T: Element + Float,
146    D: Dimension,
147{
148    let zero = <T as Element>::zero();
149    masked_unary_domain(ma, T::ln, move |x| x > zero)
150}
151
152/// Base-2 log with auto-masking of non-positive inputs.
153pub fn log2_domain<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
154where
155    T: Element + Float,
156    D: Dimension,
157{
158    let zero = <T as Element>::zero();
159    masked_unary_domain(ma, T::log2, move |x| x > zero)
160}
161
162/// Base-10 log with auto-masking of non-positive inputs.
163pub fn log10_domain<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
164where
165    T: Element + Float,
166    D: Dimension,
167{
168    let zero = <T as Element>::zero();
169    masked_unary_domain(ma, T::log10, move |x| x > zero)
170}
171
172/// Square root with auto-masking of negative inputs.
173///
174/// Equivalent to `numpy.ma.sqrt`. Any unmasked element `x < 0` is
175/// added to the result mask.
176pub fn sqrt_domain<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
177where
178    T: Element + Float,
179    D: Dimension,
180{
181    let zero = <T as Element>::zero();
182    masked_unary_domain(ma, T::sqrt, move |x| x >= zero)
183}
184
185/// Arc sine with auto-masking of out-of-domain (`|x| > 1`) inputs.
186///
187/// Equivalent to `numpy.ma.arcsin`.
188pub fn arcsin_domain<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
189where
190    T: Element + Float,
191    D: Dimension,
192{
193    let one = <T as Element>::one();
194    masked_unary_domain(ma, T::asin, move |x| x.abs() <= one)
195}
196
197/// Arc cosine with auto-masking of out-of-domain (`|x| > 1`) inputs.
198///
199/// Equivalent to `numpy.ma.arccos`.
200pub fn arccos_domain<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
201where
202    T: Element + Float,
203    D: Dimension,
204{
205    let one = <T as Element>::one();
206    masked_unary_domain(ma, T::acos, move |x| x.abs() <= one)
207}
208
209/// Inverse hyperbolic cosine with auto-masking of inputs `< 1`.
210///
211/// Equivalent to `numpy.ma.arccosh`.
212pub fn arccosh_domain<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
213where
214    T: Element + Float,
215    D: Dimension,
216{
217    let one = <T as Element>::one();
218    masked_unary_domain(ma, T::acosh, move |x| x >= one)
219}
220
221/// Inverse hyperbolic tangent with auto-masking of inputs `|x| >= 1`.
222///
223/// Equivalent to `numpy.ma.arctanh`.
224pub fn arctanh_domain<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
225where
226    T: Element + Float,
227    D: Dimension,
228{
229    let one = <T as Element>::one();
230    masked_unary_domain(ma, T::atanh, move |x| x.abs() < one)
231}
232
233/// Division with auto-masking of zero denominators.
234///
235/// Equivalent to `numpy.ma.divide`. Any position where the denominator
236/// is exactly zero is added to the result mask.
237pub fn divide_domain<T, D>(
238    a: &MaskedArray<T, D>,
239    b: &MaskedArray<T, D>,
240) -> FerrayResult<MaskedArray<T, D>>
241where
242    T: Element + Float,
243    D: Dimension,
244{
245    let zero = <T as Element>::zero();
246    masked_binary_domain(a, b, |x, y| x / y, move |_x, y| y != zero, "divide_domain")
247}
248
249// ---------------------------------------------------------------------------
250// Generic ufunc wrappers (#513)
251//
// Alongside the hand-maintained per-function wrappers below,
// `masked_unary` and `masked_binary` let callers plug in any closure and
// get mask propagation + fill-value handling for free. The named per-ufunc
// wrappers (sin, cos, log, sqrt, …) are kept as ergonomic shorthand for
// the common cases.
257// ---------------------------------------------------------------------------
258
259/// Apply any unary function to a masked array, propagating the mask.
260///
261/// Equivalent to `numpy.ma.<unary>(a)` for any unary `T -> T` closure.
262/// Masked elements are skipped; their output positions carry the
263/// masked array's `fill_value` in the result data and remain masked.
264///
265/// This is the generic escape hatch for ufuncs that don't have a
266/// dedicated wrapper in this module. Prefer a named wrapper
267/// (`sin`, `log`, `sqrt`, …) when one exists — it's the same
268/// implementation under the hood but documents intent better.
269///
270/// # Errors
271/// Returns the underlying `from_vec` shape mismatch if the result
272/// buffer is malformed (not possible in practice).
273pub fn masked_unary<T, D, F>(ma: &MaskedArray<T, D>, f: F) -> FerrayResult<MaskedArray<T, D>>
274where
275    T: Element + Copy,
276    D: Dimension,
277    F: Fn(T) -> T,
278{
279    masked_unary_op(ma, f)
280}
281
282/// Apply any binary function to two masked arrays, propagating the
283/// union of their masks.
284///
285/// Equivalent to `numpy.ma.<binary>(a, b)` for any binary
286/// `(T, T) -> T` closure. The result's mask is the elementwise OR of
287/// the two inputs' masks; masked positions in the result data carry
288/// the receiver's `fill_value`. Both inputs are broadcast to a common
289/// shape via `NumPy` rules on the slow path — the same broadcast
290/// machinery the named `add`/`multiply`/etc. wrappers use.
291///
292/// # Errors
293/// Returns `FerrayError::ShapeMismatch` if the shapes are not
294/// broadcast-compatible.
295pub fn masked_binary<T, D, F>(
296    a: &MaskedArray<T, D>,
297    b: &MaskedArray<T, D>,
298    f: F,
299    op_name: &str,
300) -> FerrayResult<MaskedArray<T, D>>
301where
302    T: Element + Copy,
303    D: Dimension,
304    F: Fn(T, T) -> T,
305{
306    masked_binary_op(a, b, f, op_name)
307}
308
309// ---------------------------------------------------------------------------
310// Trigonometric ufuncs
311// ---------------------------------------------------------------------------
312
313/// Elementwise sine on a masked array. Masked elements are skipped.
314///
315/// # Errors
316/// Returns an error only for internal failures.
317pub fn sin<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
318where
319    T: Element + Float,
320    D: Dimension,
321{
322    masked_unary_op(ma, T::sin)
323}
324
325/// Elementwise cosine on a masked array. Masked elements are skipped.
326///
327/// # Errors
328/// Returns an error only for internal failures.
329pub fn cos<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
330where
331    T: Element + Float,
332    D: Dimension,
333{
334    masked_unary_op(ma, T::cos)
335}
336
337/// Elementwise tangent on a masked array. Masked elements are skipped.
338///
339/// # Errors
340/// Returns an error only for internal failures.
341pub fn tan<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
342where
343    T: Element + Float,
344    D: Dimension,
345{
346    masked_unary_op(ma, T::tan)
347}
348
349/// Elementwise arc sine on a masked array. Masked elements are skipped.
350///
351/// # Errors
352/// Returns an error only for internal failures.
353pub fn arcsin<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
354where
355    T: Element + Float,
356    D: Dimension,
357{
358    masked_unary_op(ma, T::asin)
359}
360
361/// Elementwise arc cosine on a masked array. Masked elements are skipped.
362///
363/// # Errors
364/// Returns an error only for internal failures.
365pub fn arccos<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
366where
367    T: Element + Float,
368    D: Dimension,
369{
370    masked_unary_op(ma, T::acos)
371}
372
373/// Elementwise arc tangent on a masked array. Masked elements are skipped.
374///
375/// # Errors
376/// Returns an error only for internal failures.
377pub fn arctan<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
378where
379    T: Element + Float,
380    D: Dimension,
381{
382    masked_unary_op(ma, T::atan)
383}
384
385// ---------------------------------------------------------------------------
386// Exponential / logarithmic
387// ---------------------------------------------------------------------------
388
389/// Elementwise exponential on a masked array. Masked elements are skipped.
390///
391/// # Errors
392/// Returns an error only for internal failures.
393pub fn exp<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
394where
395    T: Element + Float,
396    D: Dimension,
397{
398    masked_unary_op(ma, T::exp)
399}
400
401/// Elementwise base-2 exponential on a masked array. Masked elements are skipped.
402///
403/// # Errors
404/// Returns an error only for internal failures.
405pub fn exp2<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
406where
407    T: Element + Float,
408    D: Dimension,
409{
410    masked_unary_op(ma, T::exp2)
411}
412
413/// Elementwise natural logarithm on a masked array. Masked elements are skipped.
414///
415/// # Errors
416/// Returns an error only for internal failures.
417pub fn log<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
418where
419    T: Element + Float,
420    D: Dimension,
421{
422    masked_unary_op(ma, T::ln)
423}
424
425/// Elementwise base-2 logarithm on a masked array. Masked elements are skipped.
426///
427/// # Errors
428/// Returns an error only for internal failures.
429pub fn log2<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
430where
431    T: Element + Float,
432    D: Dimension,
433{
434    masked_unary_op(ma, T::log2)
435}
436
437/// Elementwise base-10 logarithm on a masked array. Masked elements are skipped.
438///
439/// # Errors
440/// Returns an error only for internal failures.
441pub fn log10<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
442where
443    T: Element + Float,
444    D: Dimension,
445{
446    masked_unary_op(ma, T::log10)
447}
448
449// ---------------------------------------------------------------------------
450// Rounding
451// ---------------------------------------------------------------------------
452
453/// Elementwise floor on a masked array. Masked elements are skipped.
454///
455/// # Errors
456/// Returns an error only for internal failures.
457pub fn floor<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
458where
459    T: Element + Float,
460    D: Dimension,
461{
462    masked_unary_op(ma, T::floor)
463}
464
465/// Elementwise ceiling on a masked array. Masked elements are skipped.
466///
467/// # Errors
468/// Returns an error only for internal failures.
469pub fn ceil<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
470where
471    T: Element + Float,
472    D: Dimension,
473{
474    masked_unary_op(ma, T::ceil)
475}
476
477// ---------------------------------------------------------------------------
478// Arithmetic ufuncs
479// ---------------------------------------------------------------------------
480
481/// Elementwise square root on a masked array. Masked elements are skipped.
482///
483/// # Errors
484/// Returns an error only for internal failures.
485pub fn sqrt<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
486where
487    T: Element + Float,
488    D: Dimension,
489{
490    masked_unary_op(ma, T::sqrt)
491}
492
493/// Elementwise absolute value on a masked array. Masked elements are skipped.
494///
495/// # Errors
496/// Returns an error only for internal failures.
497pub fn absolute<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
498where
499    T: Element + Float,
500    D: Dimension,
501{
502    masked_unary_op(ma, T::abs)
503}
504
505/// Elementwise negation on a masked array. Masked elements are skipped.
506///
507/// # Errors
508/// Returns an error only for internal failures.
509pub fn negative<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
510where
511    T: Element + Float,
512    D: Dimension,
513{
514    masked_unary_op(ma, T::neg)
515}
516
517/// Elementwise reciprocal on a masked array. Masked elements are skipped.
518///
519/// # Errors
520/// Returns an error only for internal failures.
521pub fn reciprocal<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
522where
523    T: Element + Float,
524    D: Dimension,
525{
526    masked_unary_op(ma, T::recip)
527}
528
529/// Elementwise square on a masked array. Masked elements are skipped.
530///
531/// # Errors
532/// Returns an error only for internal failures.
533pub fn square<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
534where
535    T: Element + Float,
536    D: Dimension,
537{
538    masked_unary_op(ma, |v| v * v)
539}
540
541/// Elementwise hyperbolic sine on a masked array.
542pub fn sinh<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
543where
544    T: Element + Float,
545    D: Dimension,
546{
547    masked_unary_op(ma, T::sinh)
548}
549
550/// Elementwise hyperbolic cosine on a masked array.
551pub fn cosh<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
552where
553    T: Element + Float,
554    D: Dimension,
555{
556    masked_unary_op(ma, T::cosh)
557}
558
559/// Elementwise hyperbolic tangent on a masked array.
560pub fn tanh<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
561where
562    T: Element + Float,
563    D: Dimension,
564{
565    masked_unary_op(ma, T::tanh)
566}
567
568/// Elementwise inverse hyperbolic sine on a masked array.
569pub fn arcsinh<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
570where
571    T: Element + Float,
572    D: Dimension,
573{
574    masked_unary_op(ma, T::asinh)
575}
576
577/// Elementwise inverse hyperbolic cosine on a masked array.
578pub fn arccosh<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
579where
580    T: Element + Float,
581    D: Dimension,
582{
583    masked_unary_op(ma, T::acosh)
584}
585
586/// Elementwise inverse hyperbolic tangent on a masked array.
587pub fn arctanh<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
588where
589    T: Element + Float,
590    D: Dimension,
591{
592    masked_unary_op(ma, T::atanh)
593}
594
595/// Elementwise `log(1 + x)` on a masked array.
596pub fn log1p<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
597where
598    T: Element + Float,
599    D: Dimension,
600{
601    masked_unary_op(ma, T::ln_1p)
602}
603
604/// Elementwise `exp(x) - 1` on a masked array.
605pub fn expm1<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
606where
607    T: Element + Float,
608    D: Dimension,
609{
610    masked_unary_op(ma, T::exp_m1)
611}
612
613/// Elementwise `trunc` (round toward zero) on a masked array.
614pub fn trunc<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
615where
616    T: Element + Float,
617    D: Dimension,
618{
619    masked_unary_op(ma, T::trunc)
620}
621
622/// Elementwise `round` (round to nearest, halves to even) on a masked
623/// array.
624pub fn round<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
625where
626    T: Element + Float,
627    D: Dimension,
628{
629    masked_unary_op(ma, T::round)
630}
631
632/// Elementwise `sign` (-1, 0, or 1 per element) on a masked array.
633///
634/// Matches `NumPy`'s `np.sign` rather than Rust's `f64::signum`: exact
635/// zero (both +0.0 and -0.0) returns 0.0, not +1.0. NaN propagates
636/// to NaN. For any other finite or infinite value the result is
637/// `-1.0` if negative and `+1.0` if positive.
638pub fn sign<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
639where
640    T: Element + Float,
641    D: Dimension,
642{
643    let zero = <T as Element>::zero();
644    let one = <T as Element>::one();
645    masked_unary_op(ma, move |v| {
646        if v.is_nan() {
647            v
648        } else if v == zero {
649            zero
650        } else if v < zero {
651            -one
652        } else {
653            one
654        }
655    })
656}
657
658// ---------------------------------------------------------------------------
659// Binary ufuncs on two MaskedArrays
660// ---------------------------------------------------------------------------
661
662/// Elementwise addition of two masked arrays with mask propagation.
663///
664/// # Errors
665/// Returns `FerrayError::ShapeMismatch` if shapes differ.
666pub fn add<T, D>(a: &MaskedArray<T, D>, b: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
667where
668    T: Element + Float,
669    D: Dimension,
670{
671    masked_binary_op(a, b, |x, y| x + y, "add")
672}
673
674/// Elementwise subtraction of two masked arrays with mask propagation.
675///
676/// # Errors
677/// Returns `FerrayError::ShapeMismatch` if shapes differ.
678pub fn subtract<T, D>(
679    a: &MaskedArray<T, D>,
680    b: &MaskedArray<T, D>,
681) -> FerrayResult<MaskedArray<T, D>>
682where
683    T: Element + Float,
684    D: Dimension,
685{
686    masked_binary_op(a, b, |x, y| x - y, "subtract")
687}
688
689/// Elementwise multiplication of two masked arrays with mask propagation.
690///
691/// # Errors
692/// Returns `FerrayError::ShapeMismatch` if shapes differ.
693pub fn multiply<T, D>(
694    a: &MaskedArray<T, D>,
695    b: &MaskedArray<T, D>,
696) -> FerrayResult<MaskedArray<T, D>>
697where
698    T: Element + Float,
699    D: Dimension,
700{
701    masked_binary_op(a, b, |x, y| x * y, "multiply")
702}
703
704/// Elementwise division of two masked arrays with mask propagation.
705///
706/// # Errors
707/// Returns `FerrayError::ShapeMismatch` if shapes differ.
708pub fn divide<T, D>(a: &MaskedArray<T, D>, b: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
709where
710    T: Element + Float,
711    D: Dimension,
712{
713    masked_binary_op(a, b, |x, y| x / y, "divide")
714}
715
716/// Elementwise power of two masked arrays with mask propagation.
717///
718/// # Errors
719/// Returns `FerrayError::ShapeMismatch` if shapes differ.
720pub fn power<T, D>(a: &MaskedArray<T, D>, b: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
721where
722    T: Element + Float,
723    D: Dimension,
724{
725    masked_binary_op(a, b, T::powf, "power")
726}
727
728#[cfg(test)]
729mod tests {
730    use super::*;
731    use ferray_core::Array;
732    use ferray_core::dimension::Ix1;
733
734    fn make_ma(data: Vec<f64>, mask: Vec<bool>) -> MaskedArray<f64, Ix1> {
735        let n = data.len();
736        let d = Array::<f64, Ix1>::from_vec(Ix1::new([n]), data).unwrap();
737        let m = Array::<bool, Ix1>::from_vec(Ix1::new([n]), mask).unwrap();
738        MaskedArray::new(d, m).unwrap()
739    }
740
741    // ---- generic masked_unary / masked_binary (#513) ----
742
743    #[test]
744    fn masked_unary_applies_closure_to_unmasked_only() {
745        // Closure that's not exposed as a named wrapper: `x * 10 + 1`.
746        let ma = make_ma(vec![1.0, 2.0, 3.0, 4.0], vec![false, true, false, false]);
747        let r = masked_unary(&ma, |x| x.mul_add(10.0, 1.0)).unwrap();
748        let d: Vec<f64> = r.data().iter().copied().collect();
749        // Position 1 is masked → retains the fill_value (default: 0.0).
750        assert_eq!(d, vec![11.0, 0.0, 31.0, 41.0]);
751        // Mask is preserved identically.
752        let m: Vec<bool> = r.mask().iter().copied().collect();
753        assert_eq!(m, vec![false, true, false, false]);
754    }
755
756    #[test]
757    fn masked_unary_preserves_fill_value() {
758        let mut ma = make_ma(vec![1.0, 2.0, 3.0], vec![false, true, false]);
759        ma.set_fill_value(-99.0);
760        let r = masked_unary(&ma, f64::sqrt).unwrap();
761        let d: Vec<f64> = r.data().iter().copied().collect();
762        assert_eq!(d[1], -99.0);
763        assert_eq!(r.fill_value(), -99.0);
764    }
765
766    #[test]
767    fn masked_binary_mask_union_and_custom_closure() {
768        // Custom binary op: `a * 2 + b`.
769        let a = make_ma(vec![1.0, 2.0, 3.0, 4.0], vec![false, true, false, false]);
770        let b = make_ma(
771            vec![10.0, 20.0, 30.0, 40.0],
772            vec![false, false, true, false],
773        );
774        let r = masked_binary(&a, &b, |x, y| x.mul_add(2.0, y), "test_op").unwrap();
775        let d: Vec<f64> = r.data().iter().copied().collect();
776        // Position 0: 2*1 + 10 = 12; position 1: masked; position 2: masked;
777        // position 3: 2*4 + 40 = 48.
778        assert_eq!(d[0], 12.0);
779        assert_eq!(d[3], 48.0);
780        // Mask is the OR of both inputs.
781        let m: Vec<bool> = r.mask().iter().copied().collect();
782        assert_eq!(m, vec![false, true, true, false]);
783    }
784
785    #[test]
786    fn masked_binary_broadcasts_cross_shape() {
787        // masked_binary inherits broadcasting from masked_binary_op.
788        use ferray_core::dimension::Ix2;
789        let d1 = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
790            .unwrap();
791        let m1 = Array::<bool, Ix2>::from_vec(Ix2::new([2, 3]), vec![false; 6]).unwrap();
792        let a = MaskedArray::new(d1, m1).unwrap();
793        let d2 = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
794        let m2 = Array::<bool, Ix2>::from_vec(Ix2::new([1, 3]), vec![false; 3]).unwrap();
795        let b = MaskedArray::new(d2, m2).unwrap();
796        let r = masked_binary(&a, &b, |x, y| x + y, "add_broadcast").unwrap();
797        let d: Vec<f64> = r.data().iter().copied().collect();
798        assert_eq!(d, vec![11.0, 22.0, 33.0, 14.0, 25.0, 36.0]);
799    }
800
801    // ---- new named ufunc wrappers ----
802
803    #[test]
804    fn sinh_cosh_tanh_skip_masked() {
805        let ma = make_ma(vec![0.0, 1.0, -1.0], vec![false, true, false]);
806        let sh = sinh(&ma).unwrap();
807        let ch = cosh(&ma).unwrap();
808        let th = tanh(&ma).unwrap();
809        let sd: Vec<f64> = sh.data().iter().copied().collect();
810        let cd: Vec<f64> = ch.data().iter().copied().collect();
811        let td: Vec<f64> = th.data().iter().copied().collect();
812        assert!((sd[0] - 0.0).abs() < 1e-12);
813        assert!((cd[0] - 1.0).abs() < 1e-12);
814        assert!((td[0] - 0.0).abs() < 1e-12);
815        // Position 1 masked → fill value (0.0).
816        assert_eq!(sd[1], 0.0);
817        assert_eq!(cd[1], 0.0);
818        assert_eq!(td[1], 0.0);
819        assert!((sd[2] - (-1.0_f64).sinh()).abs() < 1e-12);
820    }
821
822    #[test]
823    fn log1p_expm1_are_precise_near_zero() {
824        let ma = make_ma(vec![1e-15, 0.5, -0.5], vec![false, true, false]);
825        let l = log1p(&ma).unwrap();
826        let e = expm1(&ma).unwrap();
827        let ld: Vec<f64> = l.data().iter().copied().collect();
828        let ed: Vec<f64> = e.data().iter().copied().collect();
829        // log1p(1e-15) ≈ 1e-15 (not lost to floating point)
830        assert!((ld[0] - 1e-15_f64.ln_1p()).abs() < 1e-25);
831        assert!((ed[2] - (-0.5_f64).exp_m1()).abs() < 1e-12);
832    }
833
834    #[test]
835    fn trunc_round_sign_basic() {
836        let ma = make_ma(vec![1.7, -2.5, 0.0, -3.2], vec![false, false, false, false]);
837        let t = trunc(&ma).unwrap();
838        let r = round(&ma).unwrap();
839        let s = sign(&ma).unwrap();
840        let td: Vec<f64> = t.data().iter().copied().collect();
841        let rd: Vec<f64> = r.data().iter().copied().collect();
842        let sd: Vec<f64> = s.data().iter().copied().collect();
843        assert_eq!(td, vec![1.0, -2.0, 0.0, -3.0]);
844        // f64::round on -2.5 rounds away from zero to -3 in Rust stdlib
845        // (NOT round-half-to-even). Hand-check matches that behavior.
846        assert_eq!(rd, vec![2.0, -3.0, 0.0, -3.0]);
847        assert_eq!(sd, vec![1.0, -1.0, 0.0, -1.0]);
848    }
849
850    #[test]
851    fn arcsinh_arccosh_arctanh_masked_positions_use_fill() {
852        let ma = make_ma(vec![0.0, 2.0, 0.5, -1.0], vec![false, true, false, true]);
853        let a = arcsinh(&ma).unwrap();
854        let ad: Vec<f64> = a.data().iter().copied().collect();
855        assert!((ad[0] - 0.0_f64.asinh()).abs() < 1e-12);
856        // Masked → 0.0 fill
857        assert_eq!(ad[1], 0.0);
858        assert_eq!(ad[3], 0.0);
859
860        // arccosh needs x >= 1, so use a different input for that test.
861        let ma2 = make_ma(vec![1.0, 2.0, 5.0], vec![false, false, false]);
862        let ac = arccosh(&ma2).unwrap();
863        let acd: Vec<f64> = ac.data().iter().copied().collect();
864        assert!((acd[0] - 0.0).abs() < 1e-12); // arccosh(1) = 0
865        assert!((acd[1] - 2.0_f64.acosh()).abs() < 1e-12);
866
867        // arctanh needs |x| < 1.
868        let ma3 = make_ma(vec![0.0, 0.5, -0.5], vec![false, false, false]);
869        let at = arctanh(&ma3).unwrap();
870        let atd: Vec<f64> = at.data().iter().copied().collect();
871        assert!((atd[1] - 0.5_f64.atanh()).abs() < 1e-12);
872    }
873
874    #[test]
875    fn masked_unary_and_named_sin_agree() {
876        // Generic path and named wrapper must produce identical results.
877        let ma = make_ma(vec![0.0, 1.0, 2.0, 3.0], vec![false, true, false, false]);
878        let via_named = sin(&ma).unwrap();
879        let via_generic = masked_unary(&ma, f64::sin).unwrap();
880        let vn: Vec<f64> = via_named.data().iter().copied().collect();
881        let vg: Vec<f64> = via_generic.data().iter().copied().collect();
882        assert_eq!(vn, vg);
883        let mn: Vec<bool> = via_named.mask().iter().copied().collect();
884        let mg: Vec<bool> = via_generic.mask().iter().copied().collect();
885        assert_eq!(mn, mg);
886    }
887
888    #[test]
889    fn masked_binary_and_named_add_agree() {
890        let a = make_ma(vec![1.0, 2.0, 3.0], vec![false, true, false]);
891        let b = make_ma(vec![10.0, 20.0, 30.0], vec![false, false, true]);
892        let via_named = add(&a, &b).unwrap();
893        let via_generic = masked_binary(&a, &b, |x, y| x + y, "add_generic").unwrap();
894        let vn: Vec<f64> = via_named.data().iter().copied().collect();
895        let vg: Vec<f64> = via_generic.data().iter().copied().collect();
896        assert_eq!(vn, vg);
897    }
898
899    // ---- domain-aware ufuncs (#503) ----
900
901    #[test]
902    fn log_domain_masks_non_positive_inputs() {
903        // [1.0, 2.0, -1.0, 0.0, 3.0] with no existing mask
904        let ma = make_ma(
905            vec![1.0, 2.0, -1.0, 0.0, 3.0],
906            vec![false, false, false, false, false],
907        );
908        let r = log_domain(&ma).unwrap();
909        let m: Vec<bool> = r.mask().iter().copied().collect();
910        // Positions 2 (-1) and 3 (0) are out of domain → masked.
911        assert_eq!(m, vec![false, false, true, true, false]);
912        // Unmasked positions have correct log values.
913        let d: Vec<f64> = r.data().iter().copied().collect();
914        assert!((d[0] - 0.0).abs() < 1e-12);
915        assert!((d[1] - 2.0_f64.ln()).abs() < 1e-12);
916        assert!((d[4] - 3.0_f64.ln()).abs() < 1e-12);
917        // Masked positions carry the fill value (0.0 default).
918        assert_eq!(d[2], 0.0);
919        assert_eq!(d[3], 0.0);
920    }
921
922    #[test]
923    fn log_domain_preserves_existing_mask() {
924        // Position 1 is already masked; position 3 goes to masked via domain.
925        let ma = make_ma(vec![1.0, 2.0, 5.0, -1.0], vec![false, true, false, false]);
926        let r = log_domain(&ma).unwrap();
927        let m: Vec<bool> = r.mask().iter().copied().collect();
928        assert_eq!(m, vec![false, true, false, true]);
929    }
930
931    #[test]
932    fn log_domain_vs_plain_log_on_negative_input() {
933        // Plain `log` produces NaN at the negative position; `log_domain`
934        // masks it and substitutes the fill value.
935        let ma = make_ma(vec![1.0, -2.0, 3.0], vec![false, false, false]);
936        let plain = log(&ma).unwrap();
937        let domain = log_domain(&ma).unwrap();
938        let pd: Vec<f64> = plain.data().iter().copied().collect();
939        let dd: Vec<f64> = domain.data().iter().copied().collect();
940        // Plain: position 1 is NaN.
941        assert!(pd[1].is_nan());
942        // Domain: position 1 is the fill value, and the mask is set.
943        assert_eq!(dd[1], 0.0);
944        assert!(domain.mask().as_slice().unwrap()[1]);
945    }
946
947    #[test]
948    fn sqrt_domain_masks_negative_inputs() {
949        let ma = make_ma(
950            vec![0.0, 1.0, 4.0, -9.0, -1e-10],
951            vec![false, false, false, false, false],
952        );
953        let r = sqrt_domain(&ma).unwrap();
954        let m: Vec<bool> = r.mask().iter().copied().collect();
955        // 0.0 and positive values pass; negatives are masked.
956        assert_eq!(m, vec![false, false, false, true, true]);
957        let d: Vec<f64> = r.data().iter().copied().collect();
958        assert!((d[0] - 0.0).abs() < 1e-12);
959        assert!((d[1] - 1.0).abs() < 1e-12);
960        assert!((d[2] - 2.0).abs() < 1e-12);
961    }
962
963    #[test]
964    fn arcsin_domain_masks_out_of_range() {
965        let ma = make_ma(
966            vec![-1.5, -0.5, 0.0, 0.5, 1.5],
967            vec![false, false, false, false, false],
968        );
969        let r = arcsin_domain(&ma).unwrap();
970        let m: Vec<bool> = r.mask().iter().copied().collect();
971        // |x| > 1 → masked; |x| <= 1 passes.
972        assert_eq!(m, vec![true, false, false, false, true]);
973        let d: Vec<f64> = r.data().iter().copied().collect();
974        assert!((d[1] - (-0.5_f64).asin()).abs() < 1e-12);
975        assert!((d[2] - 0.0).abs() < 1e-12);
976    }
977
978    #[test]
979    fn arccos_domain_masks_out_of_range() {
980        let ma = make_ma(vec![-1.5, 0.0, 1.0, 2.0], vec![false, false, false, false]);
981        let r = arccos_domain(&ma).unwrap();
982        let m: Vec<bool> = r.mask().iter().copied().collect();
983        assert_eq!(m, vec![true, false, false, true]);
984    }
985
986    #[test]
987    fn arccosh_domain_masks_below_one() {
988        // arccosh domain is x >= 1.
989        let ma = make_ma(vec![0.5, 1.0, 2.0, 10.0], vec![false, false, false, false]);
990        let r = arccosh_domain(&ma).unwrap();
991        let m: Vec<bool> = r.mask().iter().copied().collect();
992        assert_eq!(m, vec![true, false, false, false]);
993        let d: Vec<f64> = r.data().iter().copied().collect();
994        assert!((d[1] - 0.0).abs() < 1e-12); // acosh(1) = 0
995        assert!((d[2] - 2.0_f64.acosh()).abs() < 1e-12);
996    }
997
998    #[test]
999    fn arctanh_domain_masks_boundary_and_beyond() {
1000        // arctanh domain is |x| < 1 strictly (not <=).
1001        let ma = make_ma(
1002            vec![-1.0, -0.5, 0.0, 0.5, 1.0],
1003            vec![false, false, false, false, false],
1004        );
1005        let r = arctanh_domain(&ma).unwrap();
1006        let m: Vec<bool> = r.mask().iter().copied().collect();
1007        assert_eq!(m, vec![true, false, false, false, true]);
1008    }
1009
1010    #[test]
1011    fn divide_domain_masks_zero_denominators() {
1012        let num = make_ma(vec![1.0, 2.0, 3.0, 4.0], vec![false, false, false, false]);
1013        let den = make_ma(vec![2.0, 0.0, 1.0, 0.0], vec![false, false, false, false]);
1014        let r = divide_domain(&num, &den).unwrap();
1015        let m: Vec<bool> = r.mask().iter().copied().collect();
1016        // Positions 1 and 3 have zero denominators → masked.
1017        assert_eq!(m, vec![false, true, false, true]);
1018        let d: Vec<f64> = r.data().iter().copied().collect();
1019        assert!((d[0] - 0.5).abs() < 1e-12);
1020        assert!((d[2] - 3.0).abs() < 1e-12);
1021    }
1022
1023    #[test]
1024    fn divide_domain_preserves_numerator_and_denominator_masks() {
1025        let num = make_ma(vec![1.0, 2.0, 3.0, 4.0], vec![false, true, false, false]);
1026        let den = make_ma(vec![2.0, 5.0, 0.0, 4.0], vec![false, false, false, true]);
1027        let r = divide_domain(&num, &den).unwrap();
1028        let m: Vec<bool> = r.mask().iter().copied().collect();
1029        // pos 1: num masked; pos 2: zero denom; pos 3: denom masked.
1030        assert_eq!(m, vec![false, true, true, true]);
1031    }
1032
1033    #[test]
1034    fn masked_unary_domain_generic_path() {
1035        // Custom domain: only positions where x is even allowed.
1036        let ma = make_ma(
1037            vec![1.0, 2.0, 3.0, 4.0, 5.0],
1038            vec![false, false, false, false, false],
1039        );
1040        let r = masked_unary_domain(&ma, |x| x * 2.0, |x| (x as i32) % 2 == 0).unwrap();
1041        let m: Vec<bool> = r.mask().iter().copied().collect();
1042        assert_eq!(m, vec![true, false, true, false, true]);
1043        let d: Vec<f64> = r.data().iter().copied().collect();
1044        assert_eq!(d[1], 4.0);
1045        assert_eq!(d[3], 8.0);
1046    }
1047
1048    #[test]
1049    fn masked_binary_domain_rejects_mismatched_shapes() {
1050        let a = make_ma(vec![1.0, 2.0], vec![false, false]);
1051        let b = make_ma(vec![1.0, 2.0, 3.0], vec![false, false, false]);
1052        assert!(divide_domain(&a, &b).is_err());
1053    }
1054
1055    #[test]
1056    fn log_domain_with_custom_fill_value() {
1057        let mut ma = make_ma(vec![1.0, -1.0, 2.0], vec![false, false, false]);
1058        ma.set_fill_value(-999.0);
1059        let r = log_domain(&ma).unwrap();
1060        let d: Vec<f64> = r.data().iter().copied().collect();
1061        // Position 1 is auto-masked → carries fill value.
1062        assert_eq!(d[1], -999.0);
1063        assert_eq!(r.fill_value(), -999.0);
1064    }
1065}