prosia_extensions/
lib.rs

pub mod io {
    use std::fs::File;
    use std::io::BufRead;
    use std::io::BufReader;
    use std::io::Error;
    use std::path::Path;

    /// Reads a text file and returns its contents as one `String` per line.
    ///
    /// Line terminators (`\n` or `\r\n`) are stripped, as done by [`BufRead::lines`].
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened, if a read fails part-way through,
    /// or if the contents are not valid UTF-8. A mid-file failure is reported instead of
    /// silently returning only the lines read so far, so a truncated result is never
    /// mistaken for the whole file.
    pub fn load_file_as_str<P: AsRef<Path>>(path: P) -> Result<Vec<String>, Error> {
        let reader = BufReader::new(File::open(path)?);
        // Collecting into `Result` stops at, and returns, the first I/O or UTF-8 error.
        reader.lines().collect()
    }
}
16
17pub mod arrays {
18    use std::fmt;
19
20    use ndarray::s;
21    use ndarray::{ArrayBase, Axis, Ix1, OwnedRepr};
22    use ndarray::{Data, Dimension, RemoveAxis};
23
24    use crate::types::RVector;
25
26    /// Error type for extrema operations
27    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
28    pub enum ExtremaError {
29        EmptyArray,
30        UndefinedOrder, // e.g., NaN encountered
31    }
32
33    impl fmt::Display for ExtremaError {
34        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
35            match self {
36                ExtremaError::EmptyArray => write!(f, "cannot compute extrema of empty array"),
37                ExtremaError::UndefinedOrder => {
38                    write!(f, "undefined order: encountered NaN or incomparable values")
39                }
40            }
41        }
42    }
43
44    impl std::error::Error for ExtremaError {}
45
46    /// Extension trait providing convenience methods for computing extrema
47    /// (minima, maxima) and their indices on [ndarray::ArrayBase] values.
48    pub trait ArrayExtrema<T, D>
49    where
50        D: Dimension,
51    {
52        /// Returns the maximum value in the array.
53        ///
54        /// Returns `Err(ExtremaError::UndefinedOrder)` if any NaN values are encountered,
55        /// or `Err(ExtremaError::EmptyArray)` if the array is empty.
56        ///
57        /// # Examples
58        ///
59        /// ```
60        /// use ndarray::array;
61        /// use prosia_extensions::arrays::ArrayExtrema;
62        /// use prosia_extensions::arrays::ExtremaError;
63        ///
64        /// let a = array![1, 3, 2];
65        /// assert_eq!(a.maxval(), Ok(3));
66        ///
67        /// let empty: ndarray::Array1<i32> = ndarray::Array1::from_vec(vec![]);
68        /// assert_eq!(empty.maxval(), Err(ExtremaError::EmptyArray));
69        /// ```
70        fn maxval(&self) -> Result<T, ExtremaError>;
71
72        /// Returns the minimum value in the array.
73        ///
74        /// Returns `Err(ExtremaError::UndefinedOrder)` if any NaN values are encountered,
75        /// or `Err(ExtremaError::EmptyArray)` if the array is empty.
76        fn minval(&self) -> Result<T, ExtremaError>;
77
78        /// Returns an array of maximum values along the given axis.
79        ///
80        /// Each element of the returned array is the maximum of the slice taken
81        /// along axis. Returns `Err(ExtremaError::UndefinedOrder)` if NaN values are encountered,
82        /// or `Err(ExtremaError::EmptyArray)` if the array is empty.
83        ///
84        /// # Panics
85        /// Panics if any subview is empty, though this cannot occur if self
86        /// itself is non-empty.
87        fn maxval_along(
88            &self,
89            axis: Axis,
90        ) -> Result<ArrayBase<OwnedRepr<T>, D::Smaller>, ExtremaError>;
91
92        /// Returns an array of minimum values along the given axis.
93        ///
94        /// Each element of the returned array is the minimum of the slice taken
95        /// along axis. Returns `Err(ExtremaError::UndefinedOrder)` if NaN values are encountered,
96        /// or `Err(ExtremaError::EmptyArray)` if the array is empty.
97        fn minval_along(
98            &self,
99            axis: Axis,
100        ) -> Result<ArrayBase<OwnedRepr<T>, D::Smaller>, ExtremaError>;
101
102        /// Returns the index of the maximum element in the array.
103        ///
104        /// Returns `Err(ExtremaError::UndefinedOrder)` if any NaN values are encountered,
105        /// or `Err(ExtremaError::EmptyArray)` if the array is empty.
106        ///
107        /// The index is returned in [ndarray::Dimension::Pattern] form, which matches the array's dimensionality.
108        fn argmax(&self) -> Result<D::Pattern, ExtremaError>;
109
110        /// Returns the index of the minimum element in the array.
111        ///
112        /// Returns `Err(ExtremaError::UndefinedOrder)` if any NaN values are encountered,
113        /// or `Err(ExtremaError::EmptyArray)` if the array is empty.
114        fn argmin(&self) -> Result<D::Pattern, ExtremaError>;
115
116        /// Returns an array of indices of the maximum elements along the given axis.
117        ///
118        /// Each element in the returned array is the index (within the axis) of the
119        /// maximum value of the corresponding subview. Returns `Err(ExtremaError::UndefinedOrder)`
120        /// if NaN values are encountered, or `Err(ExtremaError::EmptyArray)` if the array is empty.
121        fn argmax_along(
122            &self,
123            axis: Axis,
124        ) -> Result<ArrayBase<OwnedRepr<usize>, D::Smaller>, ExtremaError>;
125
126        /// Returns an array of indices of the minimum elements along the given axis.
127        ///
128        /// Each element in the returned array is the index (within the axis) of the
129        /// minimum value of the corresponding subview. Returns `Err(ExtremaError::UndefinedOrder)`
130        /// if NaN values are encountered, or `Err(ExtremaError::EmptyArray)` if the array is empty.
131        fn argmin_along(
132            &self,
133            axis: Axis,
134        ) -> Result<ArrayBase<OwnedRepr<usize>, D::Smaller>, ExtremaError>;
135    }
136
137    impl<T, S, D> ArrayExtrema<T, D> for ArrayBase<S, D>
138    where
139        T: PartialOrd + Copy,
140        S: Data<Elem = T>,
141        D: Dimension + RemoveAxis,
142    {
143        /// See [ArrayExtrema::maxval].
144        fn maxval(&self) -> Result<T, ExtremaError> {
145            if self.is_empty() {
146                return Err(ExtremaError::EmptyArray);
147            }
148
149            let mut max_val = None;
150            for &val in self.iter() {
151                // Check for NaN or incomparable values by comparing with itself
152                if val.partial_cmp(&val).is_none() {
153                    return Err(ExtremaError::UndefinedOrder);
154                }
155
156                match max_val {
157                    None => max_val = Some(val),
158                    Some(current_max) => {
159                        match val.partial_cmp(&current_max) {
160                            Some(std::cmp::Ordering::Greater) => max_val = Some(val),
161                            Some(_) => {} // val <= current_max, keep current_max
162                            None => return Err(ExtremaError::UndefinedOrder), // NaN or incomparable values
163                        }
164                    }
165                }
166            }
167            Ok(max_val.unwrap()) // Safe because we checked for empty array above
168        }
169
170        /// See [ArrayExtrema::minval].
171        fn minval(&self) -> Result<T, ExtremaError> {
172            if self.is_empty() {
173                return Err(ExtremaError::EmptyArray);
174            }
175
176            let mut min_val = None;
177            for &val in self.iter() {
178                // Check for NaN or incomparable values by comparing with itself
179                if val.partial_cmp(&val).is_none() {
180                    return Err(ExtremaError::UndefinedOrder);
181                }
182
183                match min_val {
184                    None => min_val = Some(val),
185                    Some(current_min) => {
186                        match val.partial_cmp(&current_min) {
187                            Some(std::cmp::Ordering::Less) => min_val = Some(val),
188                            Some(_) => {} // val >= current_min, keep current_min
189                            None => return Err(ExtremaError::UndefinedOrder), // NaN or incomparable values
190                        }
191                    }
192                }
193            }
194            Ok(min_val.unwrap()) // Safe because we checked for empty array above
195        }
196
197        /// See [ArrayExtrema::maxval_along].
198        fn maxval_along(
199            &self,
200            axis: Axis,
201        ) -> Result<ArrayBase<OwnedRepr<T>, D::Smaller>, ExtremaError> {
202            if self.is_empty() {
203                return Err(ExtremaError::EmptyArray);
204            }
205
206            let mut result = self.map_axis(axis, |subview| -> Result<T, ExtremaError> {
207                let mut max_val = None;
208                for &val in subview.iter() {
209                    // Check for NaN or incomparable values by comparing with itself
210                    if val.partial_cmp(&val).is_none() {
211                        return Err(ExtremaError::UndefinedOrder);
212                    }
213
214                    match max_val {
215                        None => max_val = Some(val),
216                        Some(current_max) => match val.partial_cmp(&current_max) {
217                            Some(std::cmp::Ordering::Greater) => max_val = Some(val),
218                            Some(_) => {}
219                            None => return Err(ExtremaError::UndefinedOrder),
220                        },
221                    }
222                }
223                Ok(max_val.unwrap()) // Safe because subview is guaranteed non-empty
224            });
225
226            // Check if any subview computation failed
227            for elem in result.iter_mut() {
228                if let Err(err) = elem {
229                    return Err(*err);
230                }
231            }
232
233            // Convert Result<T, ExtremaError> elements to T
234            let final_result = result.map(|res| res.unwrap());
235            Ok(final_result)
236        }
237
238        /// See [ArrayExtrema::minval_along].
239        fn minval_along(
240            &self,
241            axis: Axis,
242        ) -> Result<ArrayBase<OwnedRepr<T>, D::Smaller>, ExtremaError> {
243            if self.is_empty() {
244                return Err(ExtremaError::EmptyArray);
245            }
246
247            let mut result = self.map_axis(axis, |subview| -> Result<T, ExtremaError> {
248                let mut min_val = None;
249                for &val in subview.iter() {
250                    // Check for NaN or incomparable values by comparing with itself
251                    if val.partial_cmp(&val).is_none() {
252                        return Err(ExtremaError::UndefinedOrder);
253                    }
254
255                    match min_val {
256                        None => min_val = Some(val),
257                        Some(current_min) => match val.partial_cmp(&current_min) {
258                            Some(std::cmp::Ordering::Less) => min_val = Some(val),
259                            Some(_) => {}
260                            None => return Err(ExtremaError::UndefinedOrder),
261                        },
262                    }
263                }
264                Ok(min_val.unwrap()) // Safe because subview is guaranteed non-empty
265            });
266
267            // Check if any subview computation failed
268            for elem in result.iter_mut() {
269                if let Err(err) = elem {
270                    return Err(*err);
271                }
272            }
273
274            // Convert Result<T, ExtremaError> elements to T
275            let final_result = result.map(|res| res.unwrap());
276            Ok(final_result)
277        }
278
279        /// See [ArrayExtrema::argmax].
280        fn argmax(&self) -> Result<D::Pattern, ExtremaError> {
281            if self.is_empty() {
282                return Err(ExtremaError::EmptyArray);
283            }
284
285            let mut best = None;
286
287            for (idx, &val) in self.indexed_iter() {
288                // Check for NaN or incomparable values by comparing with itself
289                if val.partial_cmp(&val).is_none() {
290                    return Err(ExtremaError::UndefinedOrder);
291                }
292
293                match best {
294                    None => best = Some((idx, val)),
295                    Some((_, best_val)) => {
296                        match val.partial_cmp(&best_val) {
297                            Some(std::cmp::Ordering::Greater) => best = Some((idx, val)),
298                            Some(_) => {} // val <= best_val, keep current best
299                            None => return Err(ExtremaError::UndefinedOrder), // NaN or incomparable values
300                        }
301                    }
302                }
303            }
304
305            Ok(best.unwrap().0) // Safe because we checked for empty array above
306        }
307
308        /// See [ArrayExtrema::argmin].
309        fn argmin(&self) -> Result<D::Pattern, ExtremaError> {
310            if self.is_empty() {
311                return Err(ExtremaError::EmptyArray);
312            }
313
314            let mut best = None;
315
316            for (idx, &val) in self.indexed_iter() {
317                // Check for NaN or incomparable values by comparing with itself
318                if val.partial_cmp(&val).is_none() {
319                    return Err(ExtremaError::UndefinedOrder);
320                }
321
322                match best {
323                    None => best = Some((idx, val)),
324                    Some((_, best_val)) => {
325                        match val.partial_cmp(&best_val) {
326                            Some(std::cmp::Ordering::Less) => best = Some((idx, val)),
327                            Some(_) => {} // val >= best_val, keep current best
328                            None => return Err(ExtremaError::UndefinedOrder), // NaN or incomparable values
329                        }
330                    }
331                }
332            }
333
334            Ok(best.unwrap().0) // Safe because we checked for empty array above
335        }
336
337        /// See [ArrayExtrema::argmax_along].
338        fn argmax_along(
339            &self,
340            axis: Axis,
341        ) -> Result<ArrayBase<OwnedRepr<usize>, D::Smaller>, ExtremaError> {
342            if self.is_empty() {
343                return Err(ExtremaError::EmptyArray);
344            }
345
346            let mut result = self.map_axis(axis, |subview| -> Result<usize, ExtremaError> {
347                let mut best = None;
348
349                for (idx, &val) in subview.indexed_iter() {
350                    // Check for NaN or incomparable values by comparing with itself
351                    if val.partial_cmp(&val).is_none() {
352                        return Err(ExtremaError::UndefinedOrder);
353                    }
354
355                    match best {
356                        None => best = Some((idx, val)),
357                        Some((_, best_val)) => match val.partial_cmp(&best_val) {
358                            Some(std::cmp::Ordering::Greater) => best = Some((idx, val)),
359                            Some(_) => {}
360                            None => return Err(ExtremaError::UndefinedOrder),
361                        },
362                    }
363                }
364
365                Ok(best.unwrap().0) // Safe because subview is guaranteed non-empty
366            });
367
368            // Check if any subview computation failed
369            for elem in result.iter_mut() {
370                if let Err(err) = elem {
371                    return Err(*err);
372                }
373            }
374
375            // Convert Result<usize, ExtremaError> elements to usize
376            let final_result = result.map(|res| res.unwrap());
377            Ok(final_result)
378        }
379
380        /// See [ArrayExtrema::argmin_along].
381        fn argmin_along(
382            &self,
383            axis: Axis,
384        ) -> Result<ArrayBase<OwnedRepr<usize>, D::Smaller>, ExtremaError> {
385            if self.is_empty() {
386                return Err(ExtremaError::EmptyArray);
387            }
388
389            let mut result = self.map_axis(axis, |subview| -> Result<usize, ExtremaError> {
390                let mut best = None;
391
392                for (idx, &val) in subview.indexed_iter() {
393                    // Check for NaN or incomparable values by comparing with itself
394                    if val.partial_cmp(&val).is_none() {
395                        return Err(ExtremaError::UndefinedOrder);
396                    }
397
398                    match best {
399                        None => best = Some((idx, val)),
400                        Some((_, best_val)) => match val.partial_cmp(&best_val) {
401                            Some(std::cmp::Ordering::Less) => best = Some((idx, val)),
402                            Some(_) => {}
403                            None => return Err(ExtremaError::UndefinedOrder),
404                        },
405                    }
406                }
407
408                Ok(best.unwrap().0) // Safe because subview is guaranteed non-empty
409            });
410
411            // Check if any subview computation failed
412            for elem in result.iter_mut() {
413                if let Err(err) = elem {
414                    return Err(*err);
415                }
416            }
417
418            // Convert Result<usize, ExtremaError> elements to usize
419            let final_result = result.map(|res| res.unwrap());
420            Ok(final_result)
421        }
422    }
423
424    /// Trait providing monotonicity checks for 1-D ndarray arrays.
425    pub trait Sequence<T>
426    where
427        T: PartialOrd + Copy,
428    {
429        /// Returns true if the array is monotonically (non-strictly) increasing.
430        fn is_monotonically_increasing(&self) -> Result<bool, ExtremaError>;
431
432        /// Returns true if the array is monotonically (non-strictly) decreasing.
433        fn is_monotonically_decreasing(&self) -> Result<bool, ExtremaError>;
434
435        /// Returns true if the array is strictly increasing.
436        fn is_strictly_increasing(&self) -> Result<bool, ExtremaError>;
437
438        /// Returns true if the array is strictly decreasing.
439        fn is_strictly_decreasing(&self) -> Result<bool, ExtremaError>;
440    }
441
442    impl<T, S> Sequence<T> for ArrayBase<S, Ix1>
443    where
444        T: PartialOrd + Copy,
445        S: Data<Elem = T>,
446    {
447        fn is_monotonically_increasing(&self) -> Result<bool, ExtremaError> {
448            if self.is_empty() {
449                return Err(ExtremaError::EmptyArray);
450            }
451
452            let mut iter = self.iter().copied();
453            let mut prev = iter.next().unwrap();
454
455            if prev.partial_cmp(&prev).is_none() {
456                return Err(ExtremaError::UndefinedOrder);
457            }
458
459            for curr in iter {
460                if curr.partial_cmp(&curr).is_none() {
461                    return Err(ExtremaError::UndefinedOrder);
462                }
463
464                match curr.partial_cmp(&prev) {
465                    Some(std::cmp::Ordering::Less) => return Ok(false),
466                    Some(_) => {} // >= ok
467                    None => return Err(ExtremaError::UndefinedOrder),
468                }
469
470                prev = curr;
471            }
472
473            Ok(true)
474        }
475
476        fn is_monotonically_decreasing(&self) -> Result<bool, ExtremaError> {
477            if self.is_empty() {
478                return Err(ExtremaError::EmptyArray);
479            }
480
481            let mut iter = self.iter().copied();
482            let mut prev = iter.next().unwrap();
483
484            if prev.partial_cmp(&prev).is_none() {
485                return Err(ExtremaError::UndefinedOrder);
486            }
487
488            for curr in iter {
489                if curr.partial_cmp(&curr).is_none() {
490                    return Err(ExtremaError::UndefinedOrder);
491                }
492
493                match curr.partial_cmp(&prev) {
494                    Some(std::cmp::Ordering::Greater) => return Ok(false),
495                    Some(_) => {} // <= ok
496                    None => return Err(ExtremaError::UndefinedOrder),
497                }
498
499                prev = curr;
500            }
501
502            Ok(true)
503        }
504
505        fn is_strictly_increasing(&self) -> Result<bool, ExtremaError> {
506            if self.is_empty() {
507                return Err(ExtremaError::EmptyArray);
508            }
509
510            let mut iter = self.iter().copied();
511            let mut prev = iter.next().unwrap();
512
513            if prev.partial_cmp(&prev).is_none() {
514                return Err(ExtremaError::UndefinedOrder);
515            }
516
517            for curr in iter {
518                if curr.partial_cmp(&curr).is_none() {
519                    return Err(ExtremaError::UndefinedOrder);
520                }
521
522                match curr.partial_cmp(&prev) {
523                    Some(std::cmp::Ordering::Greater) => {} // required
524                    _ => return Ok(false),                  // <= means not strictly
525                }
526
527                prev = curr;
528            }
529
530            Ok(true)
531        }
532
533        fn is_strictly_decreasing(&self) -> Result<bool, ExtremaError> {
534            if self.is_empty() {
535                return Err(ExtremaError::EmptyArray);
536            }
537
538            let mut iter = self.iter().copied();
539            let mut prev = iter.next().unwrap();
540
541            if prev.partial_cmp(&prev).is_none() {
542                return Err(ExtremaError::UndefinedOrder);
543            }
544
545            for curr in iter {
546                if curr.partial_cmp(&curr).is_none() {
547                    return Err(ExtremaError::UndefinedOrder);
548                }
549
550                match curr.partial_cmp(&prev) {
551                    Some(std::cmp::Ordering::Less) => {} // required
552                    _ => return Ok(false),               // >= means not strictly
553                }
554
555                prev = curr;
556            }
557
558            Ok(true)
559        }
560    }
561
562    pub trait Integrable {
563        fn trapezoid(&self, x: &RVector) -> f64;
564    }
565
566    impl Integrable for RVector {
567        fn trapezoid(&self, x: &RVector) -> f64 {
568            assert_eq!(self.len(), x.len(), "Arrays must have the same length");
569
570            let y0 = self.slice(s![..-1]);
571            let y1 = self.slice(s![1..]);
572            let x0 = x.slice(s![..-1]);
573            let x1 = x.slice(s![1..]);
574
575            ((&y0 + &y1) / 2.0 * (&x1 - &x0)).sum()
576        }
577    }
578
579    /// Returns the index of the largest element in a sorted slice that is
580    /// less than or equal to the given value.
581    ///
582    /// # Arguments
583    /// * `val` - The target value to compare against.
584    /// * `array` - A slice of `f64` values. Must be sorted in non-decreasing order.
585    ///
586    /// # Returns
587    /// * `Some(index)` if an element exists in `array` such that:
588    ///   - `array[index] <= val`
589    ///   - and `array[index + 1] > val` (or `index` is the last valid element).
590    /// * `None` if:
591    ///   - the slice is empty
592    ///   - `val` is smaller than the first element
593    ///   - `val` is greater than or equal to the last element
594    ///
595    /// # Complexity
596    /// Runs in O(log n) time using binary search via `partition_point`.
597    ///
598    /// # Example
599    /// ```
600    /// use prosia_extensions::arrays::find_index_le;
601    ///
602    /// let arr = [1.0, 2.5, 4.0, 7.0];
603    /// assert_eq!(find_index_le(3.0, &arr), Some(1)); // arr[1] = 2.5
604    /// assert_eq!(find_index_le(1.0, &arr), Some(0));
605    /// assert_eq!(find_index_le(7.0, &arr), None);
606    /// assert_eq!(find_index_le(0.5, &arr), None);
607    /// ```
608    pub fn find_index_le(val: f64, array: &[f64]) -> Option<usize> {
609        if array.is_empty() || val < array[0] || val >= array[array.len() - 1] {
610            return None;
611        }
612        let idx = array.partition_point(|&x| x <= val);
613        if idx > 0 { Some(idx - 1) } else { None }
614    }
615
616    /// Returns the index of the smallest element in a sorted slice that is
617    /// greater than or equal to the given value.
618    ///
619    /// # Arguments
620    /// * `val` - The target value to compare against.
621    /// * `array` - A slice of `f64` values. Must be sorted in non-decreasing order.
622    ///
623    /// # Returns
624    /// * `Some(index)` if an element exists in `array` such that:
625    ///   - `array[index] >= val`
626    ///   - and `array[index - 1] < val` (or `index` is the first valid element).
627    /// * `None` if:
628    ///   - the slice is empty
629    ///   - `val` is smaller than or equal to the first element
630    ///   - `val` is greater than the last element
631    ///
632    /// # Complexity
633    /// Runs in O(log n) time using binary search via `partition_point`.
634    ///
635    /// # Example
636    /// ```
637    /// use prosia_extensions::arrays::find_index_ge;
638    ///
639    /// let arr = [1.0, 2.5, 4.0, 7.0];
640    /// assert_eq!(find_index_ge(3.0, &arr), Some(2)); // arr[2] = 4.0
641    /// assert_eq!(find_index_ge(2.5, &arr), Some(1));
642    /// assert_eq!(find_index_ge(0.5, &arr), None);
643    /// assert_eq!(find_index_ge(8.0, &arr), None);
644    /// ```
645    pub fn find_index_ge(val: f64, array: &[f64]) -> Option<usize> {
646        if array.is_empty() || val > array[array.len() - 1] || val <= array[0] {
647            return None;
648        }
649        let idx = array.partition_point(|&x| x < val);
650        if idx < array.len() { Some(idx) } else { None }
651    }
652
653    /// Returns the index of the first element in a sorted slice that is
654    /// **greater than or equal to** `val`.
655    ///
656    /// This is equivalent to [`find_index_ge`] but clamps the result to valid indices.
657    /// If `val` is less than or equal to the first element, returns `0`.  
658    /// If `val` is greater than or equal to the last element, returns `array.len()`.
659    ///
660    /// # Examples
661    /// ```
662    /// use prosia_extensions::arrays::lower_bound_index;
663    ///
664    /// let arr = [1.0, 3.0, 5.0, 7.0];
665    ///
666    /// // Insert before any elements ≥ 4.0 → index 2
667    /// assert_eq!(lower_bound_index(4.0, &arr), 2);
668    ///
669    /// // Value below first → index 0
670    /// assert_eq!(lower_bound_index(0.5, &arr), 0);
671    ///
672    /// // Value above last → index len = 4
673    /// assert_eq!(lower_bound_index(10.0, &arr), arr.len());
674    /// ```
675    /// # See Also
676    /// [`upper_bound_index`], [`find_index_ge`]
677    pub fn lower_bound_index(val: f64, array: &[f64]) -> usize {
678        if array.is_empty() {
679            return 0;
680        }
681
682        if val <= array[0] {
683            return 0;
684        }
685
686        if val >= array[array.len() - 1] {
687            return array.len();
688        }
689
690        array.partition_point(|&x| x < val)
691    }
692
693    /// Returns the index of the **last element ≤ `val`** in a sorted slice.
694    ///
695    /// This behaves similarly to [`find_index_le`] but clamps the result to valid indices.
696    /// If `val` is less than the first element, returns `0`.  
697    /// If `val` is greater than or equal to the last element, returns `array.len() - 1`.
698    ///
699    /// # Examples
700    /// ```
701    /// use prosia_extensions::arrays::upper_bound_index;
702    /// let arr = [1.0, 3.0, 5.0, 7.0];
703    ///
704    /// // Largest element ≤ 4.0 is 3.0 at index 1
705    /// assert_eq!(upper_bound_index(4.0, &arr), 1);
706    ///
707    /// // Value below first → clamp to 0
708    /// assert_eq!(upper_bound_index(0.5, &arr), 0);
709    ///
710    /// // Value above last → clamp to last index = 3
711    /// assert_eq!(upper_bound_index(10.0, &arr), arr.len() - 1);
712    /// ```
713    /// # See Also
714    /// [`lower_bound_index`], [`find_index_le`]
715    pub fn upper_bound_index(val: f64, array: &[f64]) -> usize {
716        if array.is_empty() {
717            return 0;
718        }
719
720        if val < array[0] {
721            return 0;
722        }
723
724        if val >= array[array.len() - 1] {
725            return array.len() - 1;
726        }
727
728        array.partition_point(|&x| x <= val) - 1
729    }
730
731    #[cfg(test)]
732    mod tests {
733        use super::*;
734        use ndarray::{Array1, Array2, array};
735
736        #[test]
737        fn test_maxval_minval_nonempty() {
738            let a = array![1.0, 3.5, 2.2, -5.1, 7.3];
739            assert_eq!(a.maxval(), Ok(7.3));
740            assert_eq!(a.minval(), Ok(-5.1));
741        }
742
743        #[test]
744        fn test_maxval_minval_empty() {
745            let a: Array1<f64> = Array1::from_vec(vec![]);
746            assert_eq!(a.maxval(), Err(ExtremaError::EmptyArray));
747            assert_eq!(a.minval(), Err(ExtremaError::EmptyArray));
748        }
749
750        #[test]
751        fn test_maxval_along_axis0() {
752            let a = array![[1.0, 4.2, 2.1], [3.3, -1.5, 7.7]];
753            let result = a.maxval_along(Axis(0)).unwrap();
754            assert_eq!(result, array![3.3, 4.2, 7.7]);
755        }
756
757        #[test]
758        fn test_minval_along_axis0() {
759            let a = array![[1.0, 4.2, 2.1], [3.3, -1.5, 7.7]];
760            let result = a.minval_along(Axis(0)).unwrap();
761            assert_eq!(result, array![1.0, -1.5, 2.1]);
762        }
763
764        #[test]
765        fn test_maxval_along_axis1() {
766            let a = array![[1.0, 4.2, 2.1], [3.3, -1.5, 7.7]];
767            let result = a.maxval_along(Axis(1)).unwrap();
768            assert_eq!(result, array![4.2, 7.7]);
769        }
770
771        #[test]
772        fn test_minval_along_axis1() {
773            let a = array![[1.0, 4.2, 2.1], [3.3, -1.5, 7.7]];
774            let result = a.minval_along(Axis(1)).unwrap();
775            assert_eq!(result, array![1.0, -1.5]);
776        }
777
778        #[test]
779        fn test_argmax_argmin() {
780            let a = array![10.0, 3.1, 50.5, -2.2, 50.5];
781            assert_eq!(a.argmax(), Ok(2)); // first 50.5
782            assert_eq!(a.argmin(), Ok(3));
783        }
784
785        #[test]
786        fn test_argmax_along_axis0() {
787            let a = array![[1.0, 4.2, 2.1], [3.3, -1.5, 7.7]];
788            let result = a.argmax_along(Axis(0)).unwrap();
789            assert_eq!(result, array![1, 0, 1]);
790        }
791
792        #[test]
793        fn test_argmin_along_axis0() {
794            let a = array![[1.0, 4.2, 2.1], [3.3, -1.5, 7.7]];
795            let result = a.argmin_along(Axis(0)).unwrap();
796            assert_eq!(result, array![0, 1, 0]);
797        }
798
799        #[test]
800        fn test_argmax_along_axis1() {
801            let a = array![[1.0, 4.2, 2.1], [3.3, -1.5, 7.7]];
802            let result = a.argmax_along(Axis(1)).unwrap();
803            assert_eq!(result, array![1, 2]);
804        }
805
806        #[test]
807        fn test_argmin_along_axis1() {
808            let a = array![[1.0, 4.2, 2.1], [3.3, -1.5, 7.7]];
809            let result = a.argmin_along(Axis(1)).unwrap();
810            assert_eq!(result, array![0, 1]);
811        }
812
813        #[test]
814        fn test_maxval_minval_with_nan() {
815            let a = array![1.0, f64::NAN, 3.5];
816            // Now should return UndefinedOrder error when NaN is present
817            assert_eq!(a.maxval(), Err(ExtremaError::UndefinedOrder));
818            assert_eq!(a.minval(), Err(ExtremaError::UndefinedOrder));
819        }
820
821        #[test]
822        fn test_maxval_along_axis_with_nan() {
823            let a = array![[1.0, f64::NAN, 2.0], [3.0, 4.0, 5.0]];
824
825            // Along axis 0: column 1 has a NaN, so should return UndefinedOrder error
826            assert_eq!(a.maxval_along(Axis(0)), Err(ExtremaError::UndefinedOrder));
827
828            // Along axis 1: first row has a NaN, so should return UndefinedOrder error
829            assert_eq!(a.maxval_along(Axis(1)), Err(ExtremaError::UndefinedOrder));
830        }
831
832        #[test]
833        fn test_minval_along_axis_with_nan() {
834            let a = array![[1.0, 4.0, 2.0], [f64::NAN, -1.5, 7.0]];
835
836            // Both axes should return UndefinedOrder error due to NaN presence
837            assert_eq!(a.minval_along(Axis(0)), Err(ExtremaError::UndefinedOrder));
838            assert_eq!(a.minval_along(Axis(1)), Err(ExtremaError::UndefinedOrder));
839        }
840
841        #[test]
842        fn test_argmax_argmin_with_nan() {
843            let a = array![1.0, f64::NAN, 3.5];
844            // Should return UndefinedOrder error when NaN is present
845            assert_eq!(a.argmax(), Err(ExtremaError::UndefinedOrder));
846            assert_eq!(a.argmin(), Err(ExtremaError::UndefinedOrder));
847        }
848
849        #[test]
850        fn test_argmax_argmin_along_with_nan() {
851            let a = array![[1.0, f64::NAN, 2.0], [3.0, 4.0, 5.0]];
852
853            // Should return UndefinedOrder error due to NaN presence
854            assert_eq!(a.argmax_along(Axis(0)), Err(ExtremaError::UndefinedOrder));
855            assert_eq!(a.argmin_along(Axis(0)), Err(ExtremaError::UndefinedOrder));
856            assert_eq!(a.argmax_along(Axis(1)), Err(ExtremaError::UndefinedOrder));
857            assert_eq!(a.argmin_along(Axis(1)), Err(ExtremaError::UndefinedOrder));
858        }
859
860        #[test]
861        fn test_all_methods_empty_2d() {
862            let a: Array2<f64> = Array2::from_shape_vec((0, 3), vec![]).unwrap();
863            assert_eq!(a.maxval(), Err(ExtremaError::EmptyArray));
864            assert_eq!(a.minval(), Err(ExtremaError::EmptyArray));
865            assert_eq!(a.maxval_along(Axis(0)), Err(ExtremaError::EmptyArray));
866            assert_eq!(a.minval_along(Axis(1)), Err(ExtremaError::EmptyArray));
867            assert_eq!(a.argmax(), Err(ExtremaError::EmptyArray));
868            assert_eq!(a.argmin(), Err(ExtremaError::EmptyArray));
869            assert_eq!(a.argmax_along(Axis(0)), Err(ExtremaError::EmptyArray));
870            assert_eq!(a.argmin_along(Axis(1)), Err(ExtremaError::EmptyArray));
871        }
872
873        #[test]
874        fn test_valid_arrays_without_nan() {
875            // Test that normal arrays (without NaN) work correctly
876            let a = array![1, 5, 3, 2, 4];
877            assert_eq!(a.maxval(), Ok(5));
878            assert_eq!(a.minval(), Ok(1));
879            assert_eq!(a.argmax(), Ok(1));
880            assert_eq!(a.argmin(), Ok(0));
881
882            let b = array![[1, 2, 3], [4, 5, 6]];
883            assert_eq!(b.maxval_along(Axis(0)).unwrap(), array![4, 5, 6]);
884            assert_eq!(b.minval_along(Axis(0)).unwrap(), array![1, 2, 3]);
885            assert_eq!(b.argmax_along(Axis(1)).unwrap(), array![2, 2]);
886            assert_eq!(b.argmin_along(Axis(1)).unwrap(), array![0, 0]);
887        }
888
889        #[test]
890        fn test_single_element_arrays() {
891            let a = array![42.0];
892            assert_eq!(a.maxval(), Ok(42.0));
893            assert_eq!(a.minval(), Ok(42.0));
894            assert_eq!(a.argmax(), Ok(0));
895            assert_eq!(a.argmin(), Ok(0));
896
897            let b = array![[5.0]];
898            assert_eq!(b.maxval_along(Axis(0)).unwrap(), array![5.0]);
899            assert_eq!(b.minval_along(Axis(1)).unwrap(), array![5.0]);
900            assert_eq!(b.argmax_along(Axis(0)).unwrap(), array![0]);
901            assert_eq!(b.argmin_along(Axis(1)).unwrap(), array![0]);
902        }
903
904        #[test]
905        fn test_single_nan_element() {
906            let a = array![f64::NAN];
907            assert_eq!(a.maxval(), Err(ExtremaError::UndefinedOrder));
908            assert_eq!(a.minval(), Err(ExtremaError::UndefinedOrder));
909            assert_eq!(a.argmax(), Err(ExtremaError::UndefinedOrder));
910            assert_eq!(a.argmin(), Err(ExtremaError::UndefinedOrder));
911        }
912
913        #[test]
914        fn test_increasing_true() {
915            let a = array![1.0, 2.0, 2.0, 5.0];
916            assert!(a.is_monotonically_increasing().unwrap());
917            assert!(!a.is_strictly_increasing().unwrap());
918        }
919
920        #[test]
921        fn test_increasing_false() {
922            let a = array![1.0, 3.0, 2.0];
923            assert!(!a.is_monotonically_increasing().unwrap());
924            assert!(!a.is_strictly_increasing().unwrap());
925        }
926
927        #[test]
928        fn test_strictly_increasing_true() {
929            let a = array![1.0, 2.0, 3.0];
930            assert!(a.is_strictly_increasing().unwrap());
931            assert!(a.is_monotonically_increasing().unwrap());
932        }
933
934        #[test]
935        fn test_decreasing_true() {
936            let a = array![5.0, 4.0, 4.0, 1.0];
937            assert!(a.is_monotonically_decreasing().unwrap());
938            assert!(!a.is_strictly_decreasing().unwrap());
939        }
940
941        #[test]
942        fn test_decreasing_false() {
943            let a = array![5.0, 3.0, 4.0];
944            assert!(!a.is_monotonically_decreasing().unwrap());
945            assert!(!a.is_strictly_decreasing().unwrap());
946        }
947
948        #[test]
949        fn test_strictly_decreasing_true() {
950            let a = array![5.0, 3.0, 1.0];
951            assert!(a.is_strictly_decreasing().unwrap());
952            assert!(a.is_monotonically_decreasing().unwrap());
953        }
954
955        #[test]
956        fn test_empty_array() {
957            let a = Array1::<f64>::zeros(0);
958            assert!(matches!(
959                a.is_monotonically_increasing(),
960                Err(ExtremaError::EmptyArray)
961            ));
962            assert!(matches!(
963                a.is_monotonically_decreasing(),
964                Err(ExtremaError::EmptyArray)
965            ));
966            assert!(matches!(
967                a.is_strictly_increasing(),
968                Err(ExtremaError::EmptyArray)
969            ));
970            assert!(matches!(
971                a.is_strictly_decreasing(),
972                Err(ExtremaError::EmptyArray)
973            ));
974        }
975
976        #[test]
977        fn test_nan_propagates_as_undefined_order() {
978            let a = array![1.0, f64::NAN, 2.0];
979
980            assert!(matches!(
981                a.is_monotonically_increasing(),
982                Err(ExtremaError::UndefinedOrder)
983            ));
984            assert!(matches!(
985                a.is_monotonically_decreasing(),
986                Err(ExtremaError::UndefinedOrder)
987            ));
988            assert!(matches!(
989                a.is_strictly_increasing(),
990                Err(ExtremaError::UndefinedOrder)
991            ));
992            assert!(matches!(
993                a.is_strictly_decreasing(),
994                Err(ExtremaError::UndefinedOrder)
995            ));
996        }
997
998        #[test]
999        fn test_nan_as_first_element() {
1000            let a = array![f64::NAN, 1.0];
1001
1002            assert!(matches!(
1003                a.is_monotonically_increasing(),
1004                Err(ExtremaError::UndefinedOrder)
1005            ));
1006        }
1007    }
1008}
1009
1010pub mod types {
1011    use std::iter::Sum;
1012    use std::ops::{Add, AddAssign, Mul, Sub, SubAssign};
1013
1014    use ndarray::{Array1, Array2, Array3, Array4, ArrayView1};
1015
1016    /// Generic Vector (1D array)
1017    pub type Vector<T> = Array1<T>;
1018
1019    /// n-dimensional real vector (1D array).
1020    pub type RVector = Array1<f64>;
1021
1022    /// n-dimensional real vector view (1D view).
1023    pub type RVecView<'a> = ArrayView1<'a, f64>;
1024
1025    /// Generic matrix (2D array)
1026    pub type Matrix<T> = Array2<T>;
1027
1028    /// 2-dimensional unsigned-integer matrix.
1029    pub type UMatrix = Array2<usize>;
1030
1031    /// A real matrix (2D ndarray).
1032    pub type RMatrix = Array2<f64>;
1033
1034    /// n-dimensional real matrix view (2D view)
1035    pub type RMatView<'a> = ndarray::ArrayView2<'a, f64>;
1036
1037    /// Generic tensor (3D array)
1038    pub type Tensor<T> = Array3<T>;
1039
1040    /// A real tensor (3D ndarray).
1041    pub type RTensor = Array3<f64>;
1042
1043    /// A 4-dimensional real tensor (4D ndarray).
1044    pub type RTensor4 = Array4<f64>;
1045
1046    /// 1-dimensional unsigned-integer vector.
1047    pub type UVector = Array1<usize>;
1048
1049    /// 1-dimensional signed-integer vector.
1050    pub type IVector = Array1<isize>;
1051
1052    /// 1-dimensional boolean vector.
1053    pub type BVector = Array1<bool>;
1054
1055    #[cfg(feature = "complex")]
1056    use num_complex::Complex;
1057
1058    #[cfg(feature = "complex")]
1059    /// A fixed-length array of complex ([f64]) numbers.
1060    pub type CVector = Array1<Complex<f64>>;
1061
1062    #[cfg(feature = "complex")]
1063    /// A 2-dimensional array (matrix) of complex ([f64]) numbers.
1064    pub type CMatrix = Array2<Complex<f64>>;
1065
1066    #[cfg(feature = "complex")]
1067    /// A 3-dimensional array (tensor) of complex ([f64]) numbers.
1068    pub type CTensor = Array3<Complex<f64>>;
1069
1070    #[derive(Debug, Default, Clone, Copy)]
1071    pub struct Vec3 {
1072        pub x: f64,
1073        pub y: f64,
1074        pub z: f64,
1075    }
1076
1077    impl Vec3 {
1078        pub fn new(x: f64, y: f64, z: f64) -> Self {
1079            Self { x, y, z }
1080        }
1081
1082        pub fn to_array(&self) -> RVector {
1083            ndarray::array![self.x, self.y, self.z]
1084        }
1085
1086        pub fn from_array(arr: &RVector) -> Self {
1087            assert_eq!(arr.len(), 3, "Array must have exactly 3 elements");
1088            Self {
1089                x: arr[0],
1090                y: arr[1],
1091                z: arr[2],
1092            }
1093        }
1094
1095        pub fn set(&mut self, x: f64, y: f64, z: f64) {
1096            self.x = x;
1097            self.y = y;
1098            self.z = z;
1099        }
1100
1101        pub fn zero() -> Self {
1102            Self {
1103                x: 0.0,
1104                y: 0.0,
1105                z: 0.0,
1106            }
1107        }
1108
1109        pub fn norm(&self) -> f64 {
1110            (self.x * self.x + self.y * self.y + self.z * self.z).sqrt()
1111        }
1112
1113        pub fn dot(&self, other: &Vec3) -> f64 {
1114            self.x * other.x + self.y * other.y + self.z * other.z
1115        }
1116    }
1117
1118    // Vector * scalar
1119    impl Mul<f64> for Vec3 {
1120        type Output = Vec3;
1121
1122        fn mul(self, rhs: f64) -> Vec3 {
1123            Vec3 {
1124                x: self.x * rhs,
1125                y: self.y * rhs,
1126                z: self.z * rhs,
1127            }
1128        }
1129    }
1130
1131    // scalar * Vector (optional but often useful)
1132    impl Mul<Vec3> for f64 {
1133        type Output = Vec3;
1134
1135        fn mul(self, rhs: Vec3) -> Vec3 {
1136            rhs * self
1137        }
1138    }
1139
1140    // Vector + Vector
1141    impl Add for Vec3 {
1142        type Output = Vec3;
1143
1144        fn add(self, rhs: Vec3) -> Vec3 {
1145            Vec3 {
1146                x: self.x + rhs.x,
1147                y: self.y + rhs.y,
1148                z: self.z + rhs.z,
1149            }
1150        }
1151    }
1152
1153    impl AddAssign for Vec3 {
1154        fn add_assign(&mut self, rhs: Vec3) {
1155            self.x += rhs.x;
1156            self.y += rhs.y;
1157            self.z += rhs.z;
1158        }
1159    }
1160
1161    impl Sub for Vec3 {
1162        type Output = Vec3;
1163
1164        fn sub(self, rhs: Vec3) -> Vec3 {
1165            Vec3 {
1166                x: self.x - rhs.x,
1167                y: self.y - rhs.y,
1168                z: self.z - rhs.z,
1169            }
1170        }
1171    }
1172
1173    impl SubAssign for Vec3 {
1174        fn sub_assign(&mut self, rhs: Vec3) {
1175            self.x -= rhs.x;
1176            self.y -= rhs.y;
1177            self.z -= rhs.z;
1178        }
1179    }
1180
1181    // Enable iterator .sum()
1182    impl Sum for Vec3 {
1183        fn sum<I: Iterator<Item = Vec3>>(iter: I) -> Vec3 {
1184            iter.fold(Vec3::zero(), |a, b| a + b)
1185        }
1186    }
1187}