Skip to main content

use_numerical/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4//! Small approximate numerical helpers for `RustUse`.
5
6#[doc(inline)]
7pub use approx::{approx_eq, clamp_epsilon, relative_eq};
8#[doc(inline)]
9pub use difference::{backward_difference, central_difference, forward_difference};
10#[doc(inline)]
11pub use integration::{midpoint_rule, rectangle_rule, simpsons_rule, trapezoidal_rule};
12#[cfg(feature = "interval")]
13#[doc(inline)]
14pub use root::bisection_interval;
15#[doc(inline)]
16pub use root::{RootError, RootOptions, bisection, newton_raphson};
17
/// Floating-point comparison helpers for approximate numerical work.
pub mod approx {
    /// Converts any `f64` into a usable tolerance: non-negative and finite.
    ///
    /// `NaN` maps to `0.0`, either infinity maps to [`f64::MAX`], and every
    /// other value maps to its absolute value, so a negative epsilon behaves
    /// exactly like its magnitude.
    #[must_use]
    pub fn clamp_epsilon(epsilon: f64) -> f64 {
        match epsilon {
            e if e.is_nan() => 0.0,
            e if e.is_infinite() => f64::MAX,
            e => e.abs(),
        }
    }

    /// Settles a comparison without consulting a tolerance, when possible.
    ///
    /// Returns `Some(false)` if either operand is `NaN`, `Some(true)` if the
    /// operands are exactly equal (this includes equal infinities), and
    /// `Some(false)` if they differ while at least one is infinite. Returns
    /// `None` only when both operands are finite and distinct, i.e. when the
    /// caller must apply its tolerance.
    fn tolerance_free_verdict(a: f64, b: f64) -> Option<bool> {
        if a.is_nan() || b.is_nan() {
            return Some(false);
        }

        if a == b {
            return Some(true);
        }

        if a.is_finite() && b.is_finite() {
            None
        } else {
            Some(false)
        }
    }

    /// Checks whether `|a - b|` is within an absolute tolerance.
    ///
    /// The tolerance is normalized with [`clamp_epsilon`]. `NaN` is never
    /// equal to anything, equal infinities are equal, and an infinity is never
    /// equal to a finite value or to the opposite infinity.
    #[must_use]
    pub fn approx_eq(a: f64, b: f64, epsilon: f64) -> bool {
        tolerance_free_verdict(a, b).unwrap_or_else(|| (a - b).abs() <= clamp_epsilon(epsilon))
    }

    /// Checks whether `a` and `b` agree to within a relative tolerance.
    ///
    /// The allowed difference is `epsilon * max(|a|, |b|, 1.0)`, with `epsilon`
    /// normalized by [`clamp_epsilon`]. The `1.0` floor makes the test behave
    /// like an absolute comparison for values near zero. Non-finite operands
    /// follow the same rules as [`approx_eq`].
    #[must_use]
    pub fn relative_eq(a: f64, b: f64, epsilon: f64) -> bool {
        tolerance_free_verdict(a, b).unwrap_or_else(|| {
            // A floor of 1.0 turns this into an absolute comparison near zero.
            let scale = 1.0_f64.max(a.abs()).max(b.abs());
            (a - b).abs() <= clamp_epsilon(epsilon) * scale
        })
    }
}
81
/// Finite-difference helpers for first-derivative approximation.
pub mod difference {
    /// Estimates `f'(x)` from the samples at `x + h` and `x`.
    ///
    /// Evaluates `(f(x + h) - f(x)) / h`, calling `f(x + h)` first. For smooth
    /// `f` the truncation error shrinks linearly with `h`. A zero `h` is not
    /// rejected: the division then yields an infinity or `NaN` under ordinary
    /// IEEE-754 rules.
    #[must_use]
    pub fn forward_difference<F>(f: F, x: f64, h: f64) -> f64
    where
        F: Fn(f64) -> f64,
    {
        let ahead = f(x + h);
        let here = f(x);
        let rise = ahead - here;
        rise / h
    }

    /// Estimates `f'(x)` from the samples at `x` and `x - h`.
    ///
    /// Evaluates `(f(x) - f(x - h)) / h`, calling `f(x)` first. A zero `h` is
    /// not rejected: the division then yields an infinity or `NaN` under
    /// ordinary IEEE-754 rules.
    #[must_use]
    pub fn backward_difference<F>(f: F, x: f64, h: f64) -> f64
    where
        F: Fn(f64) -> f64,
    {
        let here = f(x);
        let behind = f(x - h);
        let rise = here - behind;
        rise / h
    }

    /// Estimates `f'(x)` symmetrically from the samples at `x + h` and `x - h`.
    ///
    /// Evaluates `(f(x + h) - f(x - h)) / (2h)`, calling `f(x + h)` first. For
    /// smooth `f` the truncation error shrinks quadratically with `h`. A zero
    /// `h` is not rejected: the division then yields an infinity or `NaN`
    /// under ordinary IEEE-754 rules.
    #[must_use]
    pub fn central_difference<F>(f: F, x: f64, h: f64) -> f64
    where
        F: Fn(f64) -> f64,
    {
        let ahead = f(x + h);
        let behind = f(x - h);
        let run = 2.0 * h;
        (ahead - behind) / run
    }
}
121
/// Deterministic numerical integration rules over `f64` intervals.
///
/// Every rule validates its inputs the same way. It returns `None` when the
/// subdivision count is unusable, when either bound is not finite, when the
/// bounds are so far apart that their width overflows `f64`, or when any
/// sampled function value is not finite. Reversed bounds (`a > b`) yield the
/// negated integral. The returned value can still be infinite if the
/// accumulated sum itself overflows.
pub mod integration {
    /// Sorts the bounds into ascending order and validates them.
    ///
    /// Returns `(start, end, sign)`, where `sign` is `1.0` for `a <= b` and
    /// `-1.0` when the bounds were swapped. Returns `None` when either bound is
    /// not finite or when `end - start` overflows to infinity.
    fn normalized_bounds(a: f64, b: f64) -> Option<(f64, f64, f64)> {
        if !a.is_finite() || !b.is_finite() {
            return None;
        }

        let (start, end, sign) = if a <= b { (a, b, 1.0) } else { (b, a, -1.0) };

        // Two finite bounds can be more than `f64::MAX` apart (for example
        // `-f64::MAX..f64::MAX`). The step would then be infinite and the first
        // sample point `start + 0.0 * step` would be `NaN`, i.e. `f` would be
        // evaluated outside the interval.
        (end - start).is_finite().then_some((start, end, sign))
    }

    /// Adds `weight * f(x)` for every `(x, weight)` sample onto `initial`.
    ///
    /// Samples are evaluated lazily and in iteration order. Evaluation stops at
    /// the first non-finite function value, and `None` is returned.
    fn weighted_sum<F, I>(f: &F, initial: f64, samples: I) -> Option<f64>
    where
        F: Fn(f64) -> f64,
        I: IntoIterator<Item = (f64, f64)>,
    {
        samples.into_iter().try_fold(initial, |sum, (x, weight)| {
            let value = f(x);
            value.is_finite().then_some(sum + weight * value)
        })
    }

    /// Approximates the integral of `f` from `a` to `b` with the left-rectangle rule.
    ///
    /// The interval is split into `n` equal subintervals, and `f` is sampled at
    /// the left edge of each one.
    ///
    /// Returns `None` in any of these cases:
    /// * `n == 0`;
    /// * either bound is not finite;
    /// * the interval width overflows;
    /// * a sampled function value is not finite.
    ///
    /// Reversed bounds return the negative integral.
    #[must_use]
    pub fn rectangle_rule<F>(f: F, a: f64, b: f64, n: usize) -> Option<f64>
    where
        F: Fn(f64) -> f64,
    {
        if n == 0 {
            return None;
        }

        let (start, end, sign) = normalized_bounds(a, b)?;
        let step = (end - start) / n as f64;
        let samples = (0..n).map(|index| (start + index as f64 * step, 1.0));
        let sum = weighted_sum(&f, 0.0, samples)?;

        Some(sign * sum * step)
    }

    /// Approximates the integral of `f` from `a` to `b` with the midpoint rule.
    ///
    /// The interval is split into `n` equal subintervals, and `f` is sampled at
    /// the centre of each one.
    ///
    /// Returns `None` in any of these cases:
    /// * `n == 0`;
    /// * either bound is not finite;
    /// * the interval width overflows;
    /// * a sampled function value is not finite.
    ///
    /// Reversed bounds return the negative integral.
    #[must_use]
    pub fn midpoint_rule<F>(f: F, a: f64, b: f64, n: usize) -> Option<f64>
    where
        F: Fn(f64) -> f64,
    {
        if n == 0 {
            return None;
        }

        let (start, end, sign) = normalized_bounds(a, b)?;
        let step = (end - start) / n as f64;
        let samples = (0..n).map(|index| (start + (index as f64 + 0.5) * step, 1.0));
        let sum = weighted_sum(&f, 0.0, samples)?;

        Some(sign * sum * step)
    }

    /// Approximates the integral of `f` from `a` to `b` with the trapezoidal rule.
    ///
    /// The interval is split into `n` equal subintervals. The two end nodes get
    /// weight `1/2` and every interior node gets weight `1`. The end nodes are
    /// evaluated at the exact bounds, not at `start + n * step`.
    ///
    /// Returns `None` in any of these cases:
    /// * `n == 0`;
    /// * either bound is not finite;
    /// * the interval width overflows;
    /// * a sampled function value is not finite.
    ///
    /// Reversed bounds return the negative integral.
    #[must_use]
    pub fn trapezoidal_rule<F>(f: F, a: f64, b: f64, n: usize) -> Option<f64>
    where
        F: Fn(f64) -> f64,
    {
        if n == 0 {
            return None;
        }

        let (start, end, sign) = normalized_bounds(a, b)?;
        let step = (end - start) / n as f64;
        let start_value = f(start);
        let end_value = f(end);

        if !start_value.is_finite() || !end_value.is_finite() {
            return None;
        }

        let interior = (1..n).map(|index| (start + index as f64 * step, 1.0));
        let sum = weighted_sum(&f, 0.5 * (start_value + end_value), interior)?;

        Some(sign * sum * step)
    }

    /// Approximates the integral of `f` from `a` to `b` with composite Simpson's rule.
    ///
    /// The interval is split into `n` equal subintervals. Nodes are weighted
    /// `1, 4, 2, 4, ..., 2, 4, 1`, and the total is scaled by `step / 3`.
    ///
    /// Returns `None` in any of these cases:
    /// * `n == 0` or `n` is odd;
    /// * either bound is not finite;
    /// * the interval width overflows;
    /// * a sampled function value is not finite.
    ///
    /// Reversed bounds return the negative integral.
    #[must_use]
    pub fn simpsons_rule<F>(f: F, a: f64, b: f64, n: usize) -> Option<f64>
    where
        F: Fn(f64) -> f64,
    {
        if n == 0 || n % 2 != 0 {
            return None;
        }

        let (start, end, sign) = normalized_bounds(a, b)?;
        let step = (end - start) / n as f64;
        let start_value = f(start);
        let end_value = f(end);

        if !start_value.is_finite() || !end_value.is_finite() {
            return None;
        }

        let interior = (1..n).map(|index| {
            let weight = if index % 2 == 0 { 2.0 } else { 4.0 };
            (start + index as f64 * step, weight)
        });
        let sum = weighted_sum(&f, start_value + end_value, interior)?;

        Some(sign * sum * step / 3.0)
    }
}
278
279/// Iterative approximate root-finding helpers.
280pub mod root {
281    use crate::approx::approx_eq;
282
283    #[cfg(feature = "interval")]
284    use use_interval::{Bound, Interval};
285
286    /// Configuration for iterative approximate root finders.
287    #[derive(Debug, Clone, Copy, PartialEq)]
288    pub struct RootOptions {
289        /// Absolute tolerance used for convergence and zero checks.
290        pub tolerance: f64,
291        /// Maximum number of solver iterations before returning an error.
292        pub max_iterations: usize,
293    }
294
295    impl Default for RootOptions {
296        fn default() -> Self {
297            Self {
298                tolerance: 1e-10,
299                max_iterations: 100,
300            }
301        }
302    }
303
304    /// Failure modes for approximate iterative root finders.
305    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
306    pub enum RootError {
307        /// The initial interval does not bracket a root.
308        InvalidInterval,
309        /// The configured tolerance is not finite and positive.
310        InvalidTolerance,
311        /// The solver did not converge within the iteration budget.
312        MaxIterationsReached,
313        /// A function evaluation, derivative evaluation, or iterate became non-finite.
314        NonFiniteValue,
315        /// A Newton-Raphson step encountered an approximately zero derivative.
316        ZeroDerivative,
317    }
318
319    fn validate_options(options: RootOptions) -> Result<(), RootError> {
320        if !options.tolerance.is_finite() || options.tolerance <= 0.0 {
321            Err(RootError::InvalidTolerance)
322        } else {
323            Ok(())
324        }
325    }
326
327    fn same_sign(left: f64, right: f64) -> bool {
328        (left < 0.0 && right < 0.0) || (left > 0.0 && right > 0.0)
329    }
330
331    fn bisection_with_policy<F>(
332        f: &F,
333        lower: f64,
334        upper: f64,
335        options: RootOptions,
336        allow_lower_endpoint: bool,
337        allow_upper_endpoint: bool,
338    ) -> Result<f64, RootError>
339    where
340        F: Fn(f64) -> f64,
341    {
342        validate_options(options)?;
343
344        if !lower.is_finite() || !upper.is_finite() || lower > upper {
345            return Err(RootError::InvalidInterval);
346        }
347
348        let lower_value = f(lower);
349        let upper_value = f(upper);
350
351        if !lower_value.is_finite() || !upper_value.is_finite() {
352            return Err(RootError::NonFiniteValue);
353        }
354
355        if approx_eq(lower_value, 0.0, options.tolerance) {
356            return if allow_lower_endpoint {
357                Ok(lower)
358            } else {
359                Err(RootError::InvalidInterval)
360            };
361        }
362
363        if approx_eq(upper_value, 0.0, options.tolerance) {
364            return if allow_upper_endpoint {
365                Ok(upper)
366            } else {
367                Err(RootError::InvalidInterval)
368            };
369        }
370
371        if same_sign(lower_value, upper_value) {
372            return Err(RootError::InvalidInterval);
373        }
374
375        let mut left = lower;
376        let mut right = upper;
377        let mut left_value = lower_value;
378
379        for _ in 0..options.max_iterations {
380            let midpoint = left + (right - left) * 0.5;
381            let midpoint_value = f(midpoint);
382
383            if !midpoint.is_finite() || !midpoint_value.is_finite() {
384                return Err(RootError::NonFiniteValue);
385            }
386
387            if approx_eq(midpoint_value, 0.0, options.tolerance)
388                || (right - left).abs() * 0.5 <= options.tolerance
389            {
390                return Ok(midpoint);
391            }
392
393            if same_sign(left_value, midpoint_value) {
394                left = midpoint;
395                left_value = midpoint_value;
396            } else {
397                right = midpoint;
398            }
399        }
400
401        Err(RootError::MaxIterationsReached)
402    }
403
404    /// Finds a root with the bisection method on a bracketing interval.
405    ///
406    /// The endpoint values must have opposite signs, unless an endpoint is
407    /// already approximately zero. The interval bounds and function values must
408    /// be finite. Convergence uses [`RootOptions::tolerance`] as an absolute
409    /// tolerance.
410    pub fn bisection<F>(
411        f: F,
412        lower: f64,
413        upper: f64,
414        options: RootOptions,
415    ) -> Result<f64, RootError>
416    where
417        F: Fn(f64) -> f64,
418    {
419        bisection_with_policy(&f, lower, upper, options, true, true)
420    }
421
422    /// Finds a root with Newton-Raphson iteration.
423    ///
424    /// This is an approximate iterative solver, not an exact equation helper.
425    /// The tolerance must be finite and positive. Non-finite values and
426    /// approximately zero derivatives return explicit errors.
427    pub fn newton_raphson<F, D>(
428        f: F,
429        derivative: D,
430        initial: f64,
431        options: RootOptions,
432    ) -> Result<f64, RootError>
433    where
434        F: Fn(f64) -> f64,
435        D: Fn(f64) -> f64,
436    {
437        validate_options(options)?;
438
439        if !initial.is_finite() {
440            return Err(RootError::NonFiniteValue);
441        }
442
443        let mut current = initial;
444
445        for _ in 0..options.max_iterations {
446            let value = f(current);
447            if !value.is_finite() {
448                return Err(RootError::NonFiniteValue);
449            }
450
451            if approx_eq(value, 0.0, options.tolerance) {
452                return Ok(current);
453            }
454
455            let slope = derivative(current);
456            if !slope.is_finite() {
457                return Err(RootError::NonFiniteValue);
458            }
459
460            if approx_eq(slope, 0.0, options.tolerance) {
461                return Err(RootError::ZeroDerivative);
462            }
463
464            let next = current - value / slope;
465            if !next.is_finite() {
466                return Err(RootError::NonFiniteValue);
467            }
468
469            if (next - current).abs() <= options.tolerance {
470                return Ok(next);
471            }
472
473            current = next;
474        }
475
476        Err(RootError::MaxIterationsReached)
477    }
478
479    /// Finds a root with the bisection method over a bounded `use-interval` interval.
480    ///
481    /// Unbounded and empty intervals return [`RootError::InvalidInterval`].
482    /// Open endpoints participate in bracketing, but only closed endpoints are
483    /// eligible for the immediate endpoint-root fast path.
484    #[cfg(feature = "interval")]
485    pub fn bisection_interval<F>(
486        f: F,
487        interval: Interval<f64>,
488        options: RootOptions,
489    ) -> Result<f64, RootError>
490    where
491        F: Fn(f64) -> f64,
492    {
493        if interval.is_empty() {
494            return Err(RootError::InvalidInterval);
495        }
496
497        let (lower, allow_lower_endpoint) = match interval.lower {
498            Bound::Open(value) => (value, false),
499            Bound::Closed(value) => (value, true),
500            Bound::Unbounded => return Err(RootError::InvalidInterval),
501        };
502
503        let (upper, allow_upper_endpoint) = match interval.upper {
504            Bound::Open(value) => (value, false),
505            Bound::Closed(value) => (value, true),
506            Bound::Unbounded => return Err(RootError::InvalidInterval),
507        };
508
509        bisection_with_policy(
510            &f,
511            lower,
512            upper,
513            options,
514            allow_lower_endpoint,
515            allow_upper_endpoint,
516        )
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::{
        RootError, RootOptions, approx_eq, backward_difference, bisection, central_difference,
        clamp_epsilon, forward_difference, midpoint_rule, newton_raphson, rectangle_rule,
        relative_eq, simpsons_rule, trapezoidal_rule,
    };

    #[cfg(feature = "interval")]
    use super::bisection_interval;

    #[cfg(feature = "interval")]
    use use_interval::Interval;

    #[test]
    fn absolute_approximation_equality_uses_absolute_difference() {
        assert!(approx_eq(1.0, 1.0 + 5.0e-7, 1.0e-6));
        assert!(!approx_eq(1.0, 1.0 + 2.0e-6, 1.0e-6));
    }

    #[test]
    fn relative_approximation_equality_uses_relative_difference() {
        assert!(relative_eq(10_000.0, 10_000.01, 1.0e-6));
        assert!(!relative_eq(10_000.0, 10_000.5, 1.0e-6));
    }

    #[test]
    fn zero_and_near_zero_comparisons_use_safe_fallback() {
        assert!(relative_eq(1.0e-12, 0.0, 1.0e-9));
        assert!(!relative_eq(1.0e-3, 0.0, 1.0e-6));
    }

    #[test]
    fn negative_epsilon_behavior_is_clamped_to_positive() {
        assert_eq!(clamp_epsilon(-1.0e-6), 1.0e-6);
        assert!(approx_eq(2.0, 2.0 + 5.0e-7, -1.0e-6));
    }

    #[test]
    fn clamp_epsilon_normalizes_non_finite_inputs() {
        assert_eq!(clamp_epsilon(f64::NAN), 0.0);
        assert_eq!(clamp_epsilon(f64::INFINITY), f64::MAX);
        assert_eq!(clamp_epsilon(f64::NEG_INFINITY), f64::MAX);
    }

    #[test]
    fn non_finite_comparison_behavior_is_explicit() {
        assert!(!approx_eq(f64::NAN, f64::NAN, 1.0e-6));
        assert!(!relative_eq(f64::NAN, 1.0, 1.0e-6));
        assert!(approx_eq(f64::INFINITY, f64::INFINITY, 1.0e-6));
        assert!(!approx_eq(f64::INFINITY, f64::NEG_INFINITY, 1.0e-6));
        assert!(!relative_eq(f64::INFINITY, 1.0, 1.0e-6));
        assert!(relative_eq(f64::INFINITY, f64::INFINITY, 1.0e-6));
        assert!(!relative_eq(f64::INFINITY, f64::NEG_INFINITY, 1.0e-6));
    }

    #[test]
    fn forward_difference_approximates_first_derivative() {
        let derivative = forward_difference(|x| x * x, 3.0, 1.0e-6);

        assert!((derivative - 6.0).abs() < 1.0e-5);
    }

    #[test]
    fn backward_difference_approximates_first_derivative() {
        let derivative = backward_difference(|x| x * x, 3.0, 1.0e-6);

        assert!((derivative - 6.0).abs() < 1.0e-5);
    }

    #[test]
    fn central_difference_approximates_first_derivative() {
        let derivative = central_difference(|x| x * x, 3.0, 1.0e-6);

        assert!((derivative - 6.0).abs() < 1.0e-5);
    }

    #[test]
    fn zero_step_difference_follows_floating_point_rules() {
        assert!(forward_difference(|x| x, 1.0, 0.0).is_nan());
        assert!(backward_difference(|x| x, 1.0, 0.0).is_nan());
        assert!(central_difference(|x| x, 1.0, 0.0).is_nan());
    }

    #[test]
    fn rectangle_rule_approximates_integrals() {
        let area = rectangle_rule(|x| x * x, 0.0, 1.0, 10_000).unwrap();

        assert!((area - 1.0 / 3.0).abs() < 1.0e-4);
    }

    #[test]
    fn midpoint_rule_approximates_integrals() {
        let area = midpoint_rule(|x| x * x, 0.0, 1.0, 1_000).unwrap();

        assert!((area - 1.0 / 3.0).abs() < 1.0e-6);
    }

    #[test]
    fn trapezoidal_rule_approximates_integrals() {
        let area = trapezoidal_rule(|x| x * x, 0.0, 1.0, 1_000).unwrap();

        assert!((area - 1.0 / 3.0).abs() < 1.0e-6);
    }

    #[test]
    fn simpsons_rule_approximates_integrals() {
        let area = simpsons_rule(|x| x * x, 0.0, 1.0, 10).unwrap();

        assert!((area - 1.0 / 3.0).abs() < 1.0e-12);
    }

    #[test]
    fn invalid_integration_subdivision_count_returns_none() {
        assert_eq!(rectangle_rule(|x| x, 0.0, 1.0, 0), None);
        assert_eq!(midpoint_rule(|x| x, 0.0, 1.0, 0), None);
        assert_eq!(trapezoidal_rule(|x| x, 0.0, 1.0, 0), None);
        assert_eq!(simpsons_rule(|x| x, 0.0, 1.0, 0), None);
    }

    #[test]
    fn simpsons_rule_rejects_odd_subdivision_counts() {
        assert_eq!(simpsons_rule(|x| x, 0.0, 1.0, 3), None);
    }

    #[test]
    fn non_finite_integration_inputs_return_none() {
        assert_eq!(rectangle_rule(|x| x, f64::NAN, 1.0, 4), None);
        assert_eq!(midpoint_rule(|x| x, 0.0, f64::INFINITY, 4), None);
        assert_eq!(trapezoidal_rule(|_| f64::NAN, 0.0, 1.0, 4), None);
        assert_eq!(simpsons_rule(|x| 1.0 / x, 0.0, 1.0, 4), None);
    }

    #[test]
    fn reversed_integration_bounds_return_negative_integral() {
        let area = trapezoidal_rule(|x| x * x, 1.0, 0.0, 1_000).unwrap();

        assert!((area + 1.0 / 3.0).abs() < 1.0e-6);
    }

    #[test]
    fn reversed_bounds_exactly_negate_every_rule() {
        let square = |x: f64| x * x;

        assert_eq!(
            rectangle_rule(square, 1.0, 0.0, 8).unwrap(),
            -rectangle_rule(square, 0.0, 1.0, 8).unwrap()
        );
        assert_eq!(
            midpoint_rule(square, 1.0, 0.0, 8).unwrap(),
            -midpoint_rule(square, 0.0, 1.0, 8).unwrap()
        );
        assert_eq!(
            simpsons_rule(square, 1.0, 0.0, 8).unwrap(),
            -simpsons_rule(square, 0.0, 1.0, 8).unwrap()
        );
    }

    #[test]
    fn bisection_succeeds_for_bracketed_root() {
        let root = bisection(|x| x * x - 2.0, 1.0, 2.0, RootOptions::default()).unwrap();

        assert!((root - 2.0_f64.sqrt()).abs() < 1.0e-8);
    }

    #[test]
    fn bisection_returns_endpoint_root_when_present() {
        let root = bisection(|x| x - 1.0, 1.0, 3.0, RootOptions::default()).unwrap();

        assert_eq!(root, 1.0);
    }

    #[test]
    fn bisection_returns_upper_endpoint_root_when_present() {
        assert_eq!(
            bisection(|x| x - 3.0, 1.0, 3.0, RootOptions::default()),
            Ok(3.0)
        );
    }

    #[test]
    fn bisection_rejects_invalid_intervals() {
        assert_eq!(
            bisection(|x| x * x + 1.0, -1.0, 1.0, RootOptions::default()),
            Err(RootError::InvalidInterval)
        );
    }

    #[test]
    fn bisection_rejects_reversed_or_non_finite_bounds() {
        assert_eq!(
            bisection(|x| x, 1.0, -1.0, RootOptions::default()),
            Err(RootError::InvalidInterval)
        );
        assert_eq!(
            bisection(|x| x, f64::NAN, 1.0, RootOptions::default()),
            Err(RootError::InvalidInterval)
        );
    }

    #[test]
    fn bisection_rejects_invalid_tolerance() {
        assert_eq!(
            bisection(
                |x| x * x - 2.0,
                1.0,
                2.0,
                RootOptions {
                    tolerance: 0.0,
                    max_iterations: 100,
                },
            ),
            Err(RootError::InvalidTolerance)
        );
        assert_eq!(
            bisection(
                |x| x * x - 2.0,
                1.0,
                2.0,
                RootOptions {
                    tolerance: f64::NAN,
                    max_iterations: 100,
                },
            ),
            Err(RootError::InvalidTolerance)
        );
    }

    #[test]
    fn bisection_reports_max_iterations_reached() {
        assert_eq!(
            bisection(
                |x| x * x - 2.0,
                1.0,
                2.0,
                RootOptions {
                    tolerance: 1.0e-20,
                    max_iterations: 1,
                },
            ),
            Err(RootError::MaxIterationsReached)
        );
    }

    #[test]
    fn newton_raphson_succeeds_for_simple_root() {
        let root =
            newton_raphson(|x| x * x - 2.0, |x| 2.0 * x, 1.0, RootOptions::default()).unwrap();

        assert!((root - 2.0_f64.sqrt()).abs() < 1.0e-8);
    }

    #[test]
    fn newton_raphson_reports_zero_derivative() {
        assert_eq!(
            newton_raphson(
                |x| x * x * x + 1.0,
                |x| 3.0 * x * x,
                0.0,
                RootOptions::default(),
            ),
            Err(RootError::ZeroDerivative)
        );
    }

    #[test]
    fn newton_raphson_rejects_invalid_tolerance() {
        assert_eq!(
            newton_raphson(
                |x| x * x - 2.0,
                |x| 2.0 * x,
                1.0,
                RootOptions {
                    tolerance: 0.0,
                    max_iterations: 100,
                },
            ),
            Err(RootError::InvalidTolerance)
        );
    }

    #[test]
    fn newton_raphson_reports_max_iterations_reached() {
        assert_eq!(
            newton_raphson(
                |x| x * x - 2.0,
                |x| 2.0 * x,
                1.0,
                RootOptions {
                    tolerance: 1.0e-20,
                    max_iterations: 1,
                },
            ),
            Err(RootError::MaxIterationsReached)
        );
    }

    #[test]
    fn non_finite_root_finding_behavior_is_reported() {
        assert_eq!(
            bisection(|_| f64::NAN, 0.0, 1.0, RootOptions::default()),
            Err(RootError::NonFiniteValue)
        );
        assert_eq!(
            newton_raphson(|_| f64::NAN, |_| 1.0, 0.0, RootOptions::default()),
            Err(RootError::NonFiniteValue)
        );
    }

    #[test]
    fn newton_raphson_reports_non_finite_inputs() {
        assert_eq!(
            newton_raphson(|x| x - 1.0, |_| 1.0, f64::INFINITY, RootOptions::default()),
            Err(RootError::NonFiniteValue)
        );
        assert_eq!(
            newton_raphson(|x| x - 1.0, |_| f64::NAN, 0.0, RootOptions::default()),
            Err(RootError::NonFiniteValue)
        );
    }

    #[cfg(feature = "interval")]
    #[test]
    fn bisection_interval_supports_bounded_intervals() {
        let root = bisection_interval(
            |x| x * x - 2.0,
            Interval::closed(1.0, 2.0),
            RootOptions::default(),
        )
        .unwrap();

        assert!((root - 2.0_f64.sqrt()).abs() < 1.0e-8);
    }

    #[cfg(feature = "interval")]
    #[test]
    fn bisection_interval_does_not_accept_open_endpoint_roots() {
        assert_eq!(
            bisection_interval(|x| x, Interval::open(0.0, 2.0), RootOptions::default()),
            Err(RootError::InvalidInterval)
        );
    }
}