rustpython_stdlib/
math.rs

1pub(crate) use math::make_module;
2
3#[pymodule]
4mod math {
    use crate::vm::{
        builtins::{try_bigint_to_f64, try_f64_to_bigint, PyFloat, PyInt, PyIntRef, PyStrInterned},
        function::{ArgIndex, ArgIntoFloat, ArgIterable, Either, OptionalArg, PosArgs},
        identifier, PyObject, PyObjectRef, PyRef, PyResult, VirtualMachine,
    };
    use itertools::Itertools;
    use malachite_bigint::BigInt;
    use num_traits::{One, Signed, ToPrimitive, Zero};
    use rustpython_common::{float_ops, int::true_div};
    use std::cmp::Ordering;
15
    // Constants
    // `e`, `pi` and `tau` are the `std::f64::consts` values renamed to their Python names;
    // the `#[pyattr]` attributes mark these items for export from the module, with
    // `name = ...` giving the Python-visible attribute name.
    #[pyattr]
    use std::f64::consts::{E as e, PI as pi, TAU as tau};
    // Positive infinity, exported as `inf`.
    #[pyattr(name = "inf")]
    const INF: f64 = f64::INFINITY;
    // Quiet NaN, exported as `nan`.
    #[pyattr(name = "nan")]
    const NAN: f64 = f64::NAN;
23
24    // Helper macro:
25    macro_rules! call_math_func {
26        ( $fun:ident, $name:ident, $vm:ident ) => {{
27            let value = *$name;
28            let result = value.$fun();
29            result_or_overflow(value, result, $vm)
30        }};
31    }
32
33    #[inline]
34    fn result_or_overflow(value: f64, result: f64, vm: &VirtualMachine) -> PyResult<f64> {
35        if !result.is_finite() && value.is_finite() {
36            // CPython doesn't return `inf` when called with finite
37            // values, it raises OverflowError instead.
38            Err(vm.new_overflow_error("math range error".to_owned()))
39        } else {
40            Ok(result)
41        }
42    }
43
44    // Number theory functions:
45    #[pyfunction]
46    fn fabs(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
47        call_math_func!(abs, x, vm)
48    }
49
50    #[pyfunction]
51    fn isfinite(x: ArgIntoFloat) -> bool {
52        x.is_finite()
53    }
54
55    #[pyfunction]
56    fn isinf(x: ArgIntoFloat) -> bool {
57        x.is_infinite()
58    }
59
60    #[pyfunction]
61    fn isnan(x: ArgIntoFloat) -> bool {
62        x.is_nan()
63    }
64
    /// Arguments accepted by `math.isclose`.
    #[derive(FromArgs)]
    struct IsCloseArgs {
        // First value to compare (positional).
        #[pyarg(positional)]
        a: ArgIntoFloat,
        // Second value to compare (positional).
        #[pyarg(positional)]
        b: ArgIntoFloat,
        // Relative tolerance; `isclose` substitutes 1e-09 when it is omitted.
        #[pyarg(named, optional)]
        rel_tol: OptionalArg<ArgIntoFloat>,
        // Absolute tolerance; `isclose` substitutes 0.0 when it is omitted.
        #[pyarg(named, optional)]
        abs_tol: OptionalArg<ArgIntoFloat>,
    }
76
77    #[pyfunction]
78    fn isclose(args: IsCloseArgs, vm: &VirtualMachine) -> PyResult<bool> {
79        let a = *args.a;
80        let b = *args.b;
81        let rel_tol = args.rel_tol.map_or(1e-09, |value| value.into());
82        let abs_tol = args.abs_tol.map_or(0.0, |value| value.into());
83
84        if rel_tol < 0.0 || abs_tol < 0.0 {
85            return Err(vm.new_value_error("tolerances must be non-negative".to_owned()));
86        }
87
88        if a == b {
89            /* short circuit exact equality -- needed to catch two infinities of
90               the same sign. And perhaps speeds things up a bit sometimes.
91            */
92            return Ok(true);
93        }
94
95        /* This catches the case of two infinities of opposite sign, or
96           one infinity and one finite number. Two infinities of opposite
97           sign would otherwise have an infinite relative tolerance.
98           Two infinities of the same sign are caught by the equality check
99           above.
100        */
101
102        if a.is_infinite() || b.is_infinite() {
103            return Ok(false);
104        }
105
106        let diff = (b - a).abs();
107
108        Ok((diff <= (rel_tol * b).abs()) || (diff <= (rel_tol * a).abs()) || (diff <= abs_tol))
109    }
110
111    #[pyfunction]
112    fn copysign(x: ArgIntoFloat, y: ArgIntoFloat) -> f64 {
113        if x.is_nan() || y.is_nan() {
114            x.into()
115        } else {
116            x.copysign(*y)
117        }
118    }
119
120    // Power and logarithmic functions:
121    #[pyfunction]
122    fn exp(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
123        call_math_func!(exp, x, vm)
124    }
125
126    #[pyfunction]
127    fn exp2(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
128        call_math_func!(exp2, x, vm)
129    }
130
131    #[pyfunction]
132    fn expm1(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
133        call_math_func!(exp_m1, x, vm)
134    }
135
136    #[pyfunction]
137    fn log(x: PyObjectRef, base: OptionalArg<ArgIntoFloat>, vm: &VirtualMachine) -> PyResult<f64> {
138        let base = base.map(|b| *b).unwrap_or(std::f64::consts::E);
139        log2(x, vm).map(|logx| logx / base.log2())
140    }
141
142    #[pyfunction]
143    fn log1p(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
144        let x = *x;
145        if x.is_nan() || x > -1.0_f64 {
146            Ok((x + 1.0_f64).ln())
147        } else {
148            Err(vm.new_value_error("math domain error".to_owned()))
149        }
150    }
151
152    /// Generates the base-2 logarithm of a BigInt `x`
153    fn int_log2(x: &BigInt) -> f64 {
154        // log2(x) = log2(2^n * 2^-n * x) = n + log2(x/2^n)
155        // If we set 2^n to be the greatest power of 2 below x, then x/2^n is in [1, 2), and can
156        // thus be converted into a float.
157        let n = x.bits() as u32 - 1;
158        let frac = true_div(x, &BigInt::from(2).pow(n));
159        f64::from(n) + frac.log2()
160    }
161
162    #[pyfunction]
163    fn log2(x: PyObjectRef, vm: &VirtualMachine) -> PyResult<f64> {
164        match x.try_float(vm) {
165            Ok(x) => {
166                let x = x.to_f64();
167                if x.is_nan() || x > 0.0_f64 {
168                    Ok(x.log2())
169                } else {
170                    Err(vm.new_value_error("math domain error".to_owned()))
171                }
172            }
173            Err(float_err) => {
174                if let Ok(x) = x.try_int(vm) {
175                    let x = x.as_bigint();
176                    if x.is_positive() {
177                        Ok(int_log2(x))
178                    } else {
179                        Err(vm.new_value_error("math domain error".to_owned()))
180                    }
181                } else {
182                    // Return the float error, as it will be more intuitive to users
183                    Err(float_err)
184                }
185            }
186        }
187    }
188
189    #[pyfunction]
190    fn log10(x: PyObjectRef, vm: &VirtualMachine) -> PyResult<f64> {
191        log2(x, vm).map(|logx| logx / 10f64.log2())
192    }
193
194    #[pyfunction]
195    fn pow(x: ArgIntoFloat, y: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
196        let x = *x;
197        let y = *y;
198
199        if x < 0.0 && x.is_finite() && y.fract() != 0.0 && y.is_finite() {
200            return Err(vm.new_value_error("math domain error".to_owned()));
201        }
202
203        if x == 0.0 && y < 0.0 && y != f64::NEG_INFINITY {
204            return Err(vm.new_value_error("math domain error".to_owned()));
205        }
206
207        let value = x.powf(y);
208
209        Ok(value)
210    }
211
212    #[pyfunction]
213    fn sqrt(value: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
214        let value = *value;
215        if value.is_nan() {
216            return Ok(value);
217        }
218        if value.is_sign_negative() {
219            return Err(vm.new_value_error("math domain error".to_owned()));
220        }
221        Ok(value.sqrt())
222    }
223
224    #[pyfunction]
225    fn isqrt(x: ArgIndex, vm: &VirtualMachine) -> PyResult<BigInt> {
226        let value = x.as_bigint();
227
228        if value.is_negative() {
229            return Err(vm.new_value_error("isqrt() argument must be nonnegative".to_owned()));
230        }
231        Ok(value.sqrt())
232    }
233
234    // Trigonometric functions:
235    #[pyfunction]
236    fn acos(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
237        let x = *x;
238        if x.is_nan() || (-1.0_f64..=1.0_f64).contains(&x) {
239            Ok(x.acos())
240        } else {
241            Err(vm.new_value_error("math domain error".to_owned()))
242        }
243    }
244
245    #[pyfunction]
246    fn asin(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
247        let x = *x;
248        if x.is_nan() || (-1.0_f64..=1.0_f64).contains(&x) {
249            Ok(x.asin())
250        } else {
251            Err(vm.new_value_error("math domain error".to_owned()))
252        }
253    }
254
255    #[pyfunction]
256    fn atan(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
257        call_math_func!(atan, x, vm)
258    }
259
260    #[pyfunction]
261    fn atan2(y: ArgIntoFloat, x: ArgIntoFloat) -> f64 {
262        y.atan2(*x)
263    }
264
265    #[pyfunction]
266    fn cos(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
267        call_math_func!(cos, x, vm)
268    }
269
270    #[pyfunction]
271    fn hypot(coordinates: PosArgs<ArgIntoFloat>) -> f64 {
272        let mut coordinates = ArgIntoFloat::vec_into_f64(coordinates.into_vec());
273        let mut max = 0.0;
274        let mut has_nan = false;
275        for f in &mut coordinates {
276            *f = f.abs();
277            if f.is_nan() {
278                has_nan = true;
279            } else if *f > max {
280                max = *f
281            }
282        }
283        // inf takes precedence over nan
284        if max.is_infinite() {
285            return max;
286        }
287        if has_nan {
288            return f64::NAN;
289        }
290        coordinates.sort_unstable_by(|x, y| x.total_cmp(y).reverse());
291        vector_norm(&coordinates)
292    }
293
294    /// Implementation of accurate hypotenuse algorithm from Borges 2019.
295    /// See https://arxiv.org/abs/1904.09481.
296    /// This assumes that its arguments are positive finite and have been scaled to avoid overflow
297    /// and underflow.
298    fn accurate_hypot(max: f64, min: f64) -> f64 {
299        if min <= max * (f64::EPSILON / 2.0).sqrt() {
300            return max;
301        }
302        let hypot = max.mul_add(max, min * min).sqrt();
303        let hypot_sq = hypot * hypot;
304        let max_sq = max * max;
305        let correction = (-min).mul_add(min, hypot_sq - max_sq) + hypot.mul_add(hypot, -hypot_sq)
306            - max.mul_add(max, -max_sq);
307        hypot - correction / (2.0 * hypot)
308    }
309
310    /// Calculates the norm of the vector given by `v`.
311    /// `v` is assumed to be a list of non-negative finite floats, sorted in descending order.
312    fn vector_norm(v: &[f64]) -> f64 {
313        // Drop zeros from the vector.
314        let zero_count = v.iter().rev().cloned().take_while(|x| *x == 0.0).count();
315        let v = &v[..v.len() - zero_count];
316        if v.is_empty() {
317            return 0.0;
318        }
319        if v.len() == 1 {
320            return v[0];
321        }
322        // Calculate scaling to avoid overflow / underflow.
323        let max = *v.first().unwrap();
324        let min = *v.last().unwrap();
325        let scale = if max > (f64::MAX / v.len() as f64).sqrt() {
326            max
327        } else if min < f64::MIN_POSITIVE.sqrt() {
328            // ^ This can be an `else if`, because if the max is near f64::MAX and the min is near
329            // f64::MIN_POSITIVE, then the min is relatively unimportant and will be effectively
330            // ignored.
331            min
332        } else {
333            1.0
334        };
335        let mut norm = v
336            .iter()
337            .copied()
338            .map(|x| x / scale)
339            .reduce(accurate_hypot)
340            .unwrap_or_default();
341        if v.len() > 2 {
342            // For larger lists of numbers, we can accumulate a rounding error, so a correction is
343            // needed, similar to that in `accurate_hypot()`.
344            // First, we estimate [sum of squares - norm^2], then we add the first-order
345            // approximation of the square root of that to `norm`.
346            let correction = v
347                .iter()
348                .copied()
349                .map(|x| (x / scale).powi(2))
350                .chain(std::iter::once(-norm * norm))
351                // Pairwise summation of floats gives less rounding error than a naive sum.
352                .tree_fold1(std::ops::Add::add)
353                .expect("expected at least 1 element");
354            norm = norm + correction / (2.0 * norm);
355        }
356        norm * scale
357    }
358
359    #[pyfunction]
360    fn dist(p: Vec<ArgIntoFloat>, q: Vec<ArgIntoFloat>, vm: &VirtualMachine) -> PyResult<f64> {
361        let mut max = 0.0;
362        let mut has_nan = false;
363
364        let p = ArgIntoFloat::vec_into_f64(p);
365        let q = ArgIntoFloat::vec_into_f64(q);
366        let mut diffs = vec![];
367
368        if p.len() != q.len() {
369            return Err(vm.new_value_error(
370                "both points must have the same number of dimensions".to_owned(),
371            ));
372        }
373
374        for i in 0..p.len() {
375            let px = p[i];
376            let qx = q[i];
377
378            let x = (px - qx).abs();
379            if x.is_nan() {
380                has_nan = true;
381            }
382
383            diffs.push(x);
384            if x > max {
385                max = x;
386            }
387        }
388
389        if max.is_infinite() {
390            return Ok(max);
391        }
392        if has_nan {
393            return Ok(f64::NAN);
394        }
395        diffs.sort_unstable_by(|x, y| x.total_cmp(y).reverse());
396        Ok(vector_norm(&diffs))
397    }
398
399    #[pyfunction]
400    fn sin(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
401        call_math_func!(sin, x, vm)
402    }
403
404    #[pyfunction]
405    fn tan(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
406        call_math_func!(tan, x, vm)
407    }
408
409    #[pyfunction]
410    fn degrees(x: ArgIntoFloat) -> f64 {
411        *x * (180.0 / std::f64::consts::PI)
412    }
413
414    #[pyfunction]
415    fn radians(x: ArgIntoFloat) -> f64 {
416        *x * (std::f64::consts::PI / 180.0)
417    }
418
419    // Hyperbolic functions:
420
421    #[pyfunction]
422    fn acosh(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
423        let x = *x;
424        if x.is_sign_negative() || x.is_zero() {
425            Err(vm.new_value_error("math domain error".to_owned()))
426        } else {
427            Ok(x.acosh())
428        }
429    }
430
431    #[pyfunction]
432    fn asinh(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
433        call_math_func!(asinh, x, vm)
434    }
435
436    #[pyfunction]
437    fn atanh(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
438        let x = *x;
439        if x >= 1.0_f64 || x <= -1.0_f64 {
440            Err(vm.new_value_error("math domain error".to_owned()))
441        } else {
442            Ok(x.atanh())
443        }
444    }
445
446    #[pyfunction]
447    fn cosh(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
448        call_math_func!(cosh, x, vm)
449    }
450
451    #[pyfunction]
452    fn sinh(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
453        call_math_func!(sinh, x, vm)
454    }
455
456    #[pyfunction]
457    fn tanh(x: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
458        call_math_func!(tanh, x, vm)
459    }
460
461    // Special functions:
462    #[pyfunction]
463    fn erf(x: ArgIntoFloat) -> f64 {
464        let x = *x;
465        if x.is_nan() {
466            x
467        } else {
468            puruspe::erf(x)
469        }
470    }
471
472    #[pyfunction]
473    fn erfc(x: ArgIntoFloat) -> f64 {
474        let x = *x;
475        if x.is_nan() {
476            x
477        } else {
478            puruspe::erfc(x)
479        }
480    }
481
482    #[pyfunction]
483    fn gamma(x: ArgIntoFloat) -> f64 {
484        let x = *x;
485        if x.is_finite() {
486            puruspe::gamma(x)
487        } else if x.is_nan() || x.is_sign_positive() {
488            x
489        } else {
490            f64::NAN
491        }
492    }
493
494    #[pyfunction]
495    fn lgamma(x: ArgIntoFloat) -> f64 {
496        let x = *x;
497        if x.is_finite() {
498            puruspe::ln_gamma(x)
499        } else if x.is_nan() {
500            x
501        } else {
502            f64::INFINITY
503        }
504    }
505
506    fn try_magic_method(
507        func_name: &'static PyStrInterned,
508        vm: &VirtualMachine,
509        value: &PyObject,
510    ) -> PyResult {
511        let method = vm.get_method_or_type_error(value.to_owned(), func_name, || {
512            format!(
513                "type '{}' doesn't define '{}' method",
514                value.class().name(),
515                func_name.as_str(),
516            )
517        })?;
518        method.call((), vm)
519    }
520
521    #[pyfunction]
522    fn trunc(x: PyObjectRef, vm: &VirtualMachine) -> PyResult {
523        try_magic_method(identifier!(vm, __trunc__), vm, &x)
524    }
525
526    #[pyfunction]
527    fn ceil(x: PyObjectRef, vm: &VirtualMachine) -> PyResult {
528        let result_or_err = try_magic_method(identifier!(vm, __ceil__), vm, &x);
529        if result_or_err.is_err() {
530            if let Some(v) = x.try_float_opt(vm) {
531                let v = try_f64_to_bigint(v?.to_f64().ceil(), vm)?;
532                return Ok(vm.ctx.new_int(v).into());
533            }
534        }
535        result_or_err
536    }
537
538    #[pyfunction]
539    fn floor(x: PyObjectRef, vm: &VirtualMachine) -> PyResult {
540        let result_or_err = try_magic_method(identifier!(vm, __floor__), vm, &x);
541        if result_or_err.is_err() {
542            if let Some(v) = x.try_float_opt(vm) {
543                let v = try_f64_to_bigint(v?.to_f64().floor(), vm)?;
544                return Ok(vm.ctx.new_int(v).into());
545            }
546        }
547        result_or_err
548    }
549
550    #[pyfunction]
551    fn frexp(x: ArgIntoFloat) -> (f64, i32) {
552        let value = *x;
553        if value.is_finite() {
554            let (m, exp) = float_ops::ufrexp(value);
555            (m * value.signum(), exp)
556        } else {
557            (value, 0)
558        }
559    }
560
561    #[pyfunction]
562    fn ldexp(
563        x: Either<PyRef<PyFloat>, PyIntRef>,
564        i: PyIntRef,
565        vm: &VirtualMachine,
566    ) -> PyResult<f64> {
567        let value = match x {
568            Either::A(f) => f.to_f64(),
569            Either::B(z) => try_bigint_to_f64(z.as_bigint(), vm)?,
570        };
571
572        if value == 0_f64 || !value.is_finite() {
573            // NaNs, zeros and infinities are returned unchanged
574            Ok(value)
575        } else {
576            let result = value * (2_f64).powf(try_bigint_to_f64(i.as_bigint(), vm)?);
577            result_or_overflow(value, result, vm)
578        }
579    }
580
581    fn math_perf_arb_len_int_op<F>(args: PosArgs<ArgIndex>, op: F, default: BigInt) -> BigInt
582    where
583        F: Fn(&BigInt, &PyInt) -> BigInt,
584    {
585        let argvec = args.into_vec();
586
587        if argvec.is_empty() {
588            return default;
589        } else if argvec.len() == 1 {
590            return op(argvec[0].as_bigint(), &argvec[0]);
591        }
592
593        let mut res = argvec[0].as_bigint().clone();
594        for num in &argvec[1..] {
595            res = op(&res, num)
596        }
597        res
598    }
599
600    #[pyfunction]
601    fn gcd(args: PosArgs<ArgIndex>) -> BigInt {
602        use num_integer::Integer;
603        math_perf_arb_len_int_op(args, |x, y| x.gcd(y.as_bigint()), BigInt::zero())
604    }
605
606    #[pyfunction]
607    fn lcm(args: PosArgs<ArgIndex>) -> BigInt {
608        use num_integer::Integer;
609        math_perf_arb_len_int_op(args, |x, y| x.lcm(y.as_bigint()), BigInt::one())
610    }
611
612    #[pyfunction]
613    fn cbrt(x: ArgIntoFloat) -> f64 {
614        x.cbrt()
615    }
616
617    #[pyfunction]
618    fn fsum(seq: ArgIterable<ArgIntoFloat>, vm: &VirtualMachine) -> PyResult<f64> {
619        let mut partials = vec![];
620        let mut special_sum = 0.0;
621        let mut inf_sum = 0.0;
622
623        for obj in seq.iter(vm)? {
624            let mut x = *obj?;
625
626            let xsave = x;
627            let mut j = 0;
628            // This inner loop applies `hi`/`lo` summation to each
629            // partial so that the list of partial sums remains exact.
630            for i in 0..partials.len() {
631                let mut y: f64 = partials[i];
632                if x.abs() < y.abs() {
633                    std::mem::swap(&mut x, &mut y);
634                }
635                // Rounded `x+y` is stored in `hi` with round-off stored in
636                // `lo`. Together `hi+lo` are exactly equal to `x+y`.
637                let hi = x + y;
638                let lo = y - (hi - x);
639                if lo != 0.0 {
640                    partials[j] = lo;
641                    j += 1;
642                }
643                x = hi;
644            }
645
646            if !x.is_finite() {
647                // a nonfinite x could arise either as
648                // a result of intermediate overflow, or
649                // as a result of a nan or inf in the
650                // summands
651                if xsave.is_finite() {
652                    return Err(vm.new_overflow_error("intermediate overflow in fsum".to_owned()));
653                }
654                if xsave.is_infinite() {
655                    inf_sum += xsave;
656                }
657                special_sum += xsave;
658                // reset partials
659                partials.clear();
660            }
661
662            if j >= partials.len() {
663                partials.push(x);
664            } else {
665                partials[j] = x;
666                partials.truncate(j + 1);
667            }
668        }
669        if special_sum != 0.0 {
670            return if inf_sum.is_nan() {
671                Err(vm.new_value_error("-inf + inf in fsum".to_owned()))
672            } else {
673                Ok(special_sum)
674            };
675        }
676
677        let mut n = partials.len();
678        if n > 0 {
679            n -= 1;
680            let mut hi = partials[n];
681
682            let mut lo = 0.0;
683            while n > 0 {
684                let x = hi;
685
686                n -= 1;
687                let y = partials[n];
688
689                hi = x + y;
690                lo = y - (hi - x);
691                if lo != 0.0 {
692                    break;
693                }
694            }
695            if n > 0 && ((lo < 0.0 && partials[n - 1] < 0.0) || (lo > 0.0 && partials[n - 1] > 0.0))
696            {
697                let y = lo + lo;
698                let x = hi + y;
699
700                // Make half-even rounding work across multiple partials.
701                // Needed so that sum([1e-16, 1, 1e16]) will round-up the last
702                // digit to two instead of down to zero (the 1e-16 makes the 1
703                // slightly closer to two).  With a potential 1 ULP rounding
704                // error fixed-up, math.fsum() can guarantee commutativity.
705                if y == x - hi {
706                    hi = x;
707                }
708            }
709
710            Ok(hi)
711        } else {
712            Ok(0.0)
713        }
714    }
715
716    #[pyfunction]
717    fn factorial(x: PyIntRef, vm: &VirtualMachine) -> PyResult<BigInt> {
718        let value = x.as_bigint();
719        let one = BigInt::one();
720        if value.is_negative() {
721            return Err(
722                vm.new_value_error("factorial() not defined for negative values".to_owned())
723            );
724        } else if *value <= one {
725            return Ok(one);
726        }
727        // start from 2, since we know that value > 1 and 1*2=2
728        let mut current = one + 1;
729        let mut product = BigInt::from(2u8);
730        while current < *value {
731            current += 1;
732            product *= &current;
733        }
734        Ok(product)
735    }
736
737    #[pyfunction]
738    fn perm(
739        n: ArgIndex,
740        k: OptionalArg<Option<ArgIndex>>,
741        vm: &VirtualMachine,
742    ) -> PyResult<BigInt> {
743        let n = n.as_bigint();
744        let k_ref;
745        let v = match k.flatten() {
746            Some(k) => {
747                k_ref = k;
748                k_ref.as_bigint()
749            }
750            None => n,
751        };
752
753        if n.is_negative() || v.is_negative() {
754            return Err(vm.new_value_error("perm() not defined for negative values".to_owned()));
755        }
756        if v > n {
757            return Ok(BigInt::zero());
758        }
759        let mut result = BigInt::one();
760        let mut current = n.clone();
761        let tmp = n - v;
762        while current > tmp {
763            result *= &current;
764            current -= 1;
765        }
766        Ok(result)
767    }
768
769    #[pyfunction]
770    fn comb(n: ArgIndex, k: ArgIndex, vm: &VirtualMachine) -> PyResult<BigInt> {
771        let mut k = k.as_bigint();
772        let n = n.as_bigint();
773        let one = BigInt::one();
774        let zero = BigInt::zero();
775
776        if n.is_negative() || k.is_negative() {
777            return Err(vm.new_value_error("comb() not defined for negative values".to_owned()));
778        }
779
780        let temp = n - k;
781        if temp.is_negative() {
782            return Ok(zero);
783        }
784
785        if temp < *k {
786            k = &temp
787        }
788
789        if k.is_zero() {
790            return Ok(one);
791        }
792
793        let mut result = n.clone();
794        let mut factor = n.clone();
795        let mut current = one;
796        while current < *k {
797            factor -= 1;
798            current += 1;
799
800            result *= &factor;
801            result /= &current;
802        }
803
804        Ok(result)
805    }
806
807    #[pyfunction]
808    fn modf(x: ArgIntoFloat) -> (f64, f64) {
809        let x = *x;
810        if !x.is_finite() {
811            if x.is_infinite() {
812                return (0.0_f64.copysign(x), x);
813            } else if x.is_nan() {
814                return (x, x);
815            }
816        }
817
818        (x.fract(), x.trunc())
819    }
820
821    #[pyfunction]
822    fn nextafter(x: ArgIntoFloat, y: ArgIntoFloat) -> f64 {
823        float_ops::nextafter(*x, *y)
824    }
825
826    #[pyfunction]
827    fn ulp(x: ArgIntoFloat) -> f64 {
828        float_ops::ulp(*x)
829    }
830
831    fn fmod(x: f64, y: f64) -> f64 {
832        if y.is_infinite() && x.is_finite() {
833            return x;
834        }
835
836        x % y
837    }
838
839    #[pyfunction(name = "fmod")]
840    fn py_fmod(x: ArgIntoFloat, y: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
841        let x = *x;
842        let y = *y;
843
844        let r = fmod(x, y);
845
846        if r.is_nan() && !x.is_nan() && !y.is_nan() {
847            return Err(vm.new_value_error("math domain error".to_owned()));
848        }
849
850        Ok(r)
851    }
852
853    #[pyfunction]
854    fn remainder(x: ArgIntoFloat, y: ArgIntoFloat, vm: &VirtualMachine) -> PyResult<f64> {
855        let x = *x;
856        let y = *y;
857
858        if x.is_finite() && y.is_finite() {
859            if y == 0.0 {
860                return Err(vm.new_value_error("math domain error".to_owned()));
861            }
862
863            let absx = x.abs();
864            let absy = y.abs();
865            let modulus = absx % absy;
866
867            let c = absy - modulus;
868            let r = match modulus.partial_cmp(&c) {
869                Some(Ordering::Less) => modulus,
870                Some(Ordering::Greater) => -c,
871                _ => modulus - 2.0 * fmod(0.5 * (absx - modulus), absy),
872            };
873
874            return Ok(1.0_f64.copysign(x) * r);
875        }
876        if x.is_infinite() && !y.is_nan() {
877            return Err(vm.new_value_error("math domain error".to_owned()));
878        }
879        if x.is_nan() || y.is_nan() {
880            return Ok(f64::NAN);
881        }
882        if y.is_infinite() {
883            Ok(x)
884        } else {
885            Err(vm.new_value_error("math domain error".to_owned()))
886        }
887    }
888
    /// Arguments accepted by `math.prod`.
    #[derive(FromArgs)]
    struct ProdArgs {
        // Iterable whose items are multiplied together (positional).
        #[pyarg(positional)]
        iterable: ArgIterable<PyObjectRef>,
        // Initial value of the product; `prod` uses the int 1 when it is omitted.
        #[pyarg(named, optional)]
        start: OptionalArg<PyObjectRef>,
    }
896
897    #[pyfunction]
898    fn prod(args: ProdArgs, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
899        let iter = args.iterable;
900
901        let mut result = args.start.unwrap_or_else(|| vm.new_pyobj(1));
902
903        // TODO: CPython has optimized implementation for this
904        // refer: https://github.com/python/cpython/blob/main/Modules/mathmodule.c#L3093-L3193
905        for obj in iter.iter(vm)? {
906            let obj = obj?;
907
908            result = vm
909                ._mul(&result, &obj)
910                .map_err(|_| vm.new_type_error("math type error".to_owned()))?;
911        }
912
913        Ok(result)
914    }
915}