rustpython-common 0.2.0

General python functions and algorithms for use in RustPython
Documentation
use num_bigint::{BigInt, ToBigInt};
use num_traits::{Float, Signed, ToPrimitive, Zero};
use std::f64;

pub fn ufrexp(value: f64) -> (f64, i32) {
    if 0.0 == value {
        (0.0, 0i32)
    } else {
        let bits = value.to_bits();
        let exponent: i32 = ((bits >> 52) & 0x7ff) as i32 - 1022;
        let mantissa_bits = bits & (0x000f_ffff_ffff_ffff) | (1022 << 52);
        (f64::from_bits(mantissa_bits), exponent)
    }
}

/// Equate an integer to a float.
///
/// Returns true if and only if, when converted to each others types, both are equal.
///
/// # Examples
///
/// ```
/// use num_bigint::BigInt;
/// use rustpython_common::float_ops::eq_int;
/// let a = 1.0f64;
/// let b = BigInt::from(1);
/// let c = 2.0f64;
/// assert!(eq_int(a, &b));
/// assert!(!eq_int(c, &b));
/// ```
///
pub fn eq_int(value: f64, other: &BigInt) -> bool {
    if let (Some(self_int), Some(other_float)) = (value.to_bigint(), other.to_f64()) {
        value == other_float && self_int == *other
    } else {
        false
    }
}

pub fn lt_int(value: f64, other_int: &BigInt) -> bool {
    match (value.to_bigint(), other_int.to_f64()) {
        (Some(self_int), Some(other_float)) => value < other_float || self_int < *other_int,
        // finite float, other_int too big for float,
        // the result depends only on other_int’s sign
        (Some(_), None) => other_int.is_positive(),
        // infinite float must be bigger or lower than any int, depending on its sign
        _ if value.is_infinite() => value.is_sign_negative(),
        // NaN, always false
        _ => false,
    }
}

pub fn gt_int(value: f64, other_int: &BigInt) -> bool {
    match (value.to_bigint(), other_int.to_f64()) {
        (Some(self_int), Some(other_float)) => value > other_float || self_int > *other_int,
        // finite float, other_int too big for float,
        // the result depends only on other_int’s sign
        (Some(_), None) => other_int.is_negative(),
        // infinite float must be bigger or lower than any int, depending on its sign
        _ if value.is_infinite() => value.is_sign_positive(),
        // NaN, always false
        _ => false,
    }
}

pub fn parse_str(literal: &str) -> Option<f64> {
    parse_inner(literal.trim().as_bytes())
}

pub fn parse_bytes(literal: &[u8]) -> Option<f64> {
    parse_inner(trim_slice(literal, |b| b.is_ascii_whitespace()))
}

fn trim_slice<T>(v: &[T], mut trim: impl FnMut(&T) -> bool) -> &[T] {
    let mut it = v.iter();
    // it.take_while_ref(&mut trim).for_each(drop);
    // hmm.. `&mut slice::Iter<_>` is not `Clone`
    // it.by_ref().rev().take_while_ref(&mut trim).for_each(drop);
    while it.clone().next().map_or(false, &mut trim) {
        it.next();
    }
    while it.clone().next_back().map_or(false, &mut trim) {
        it.next_back();
    }
    it.as_slice()
}

fn parse_inner(literal: &[u8]) -> Option<f64> {
    use lexical_parse_float::{
        format::PYTHON3_LITERAL, FromLexicalWithOptions, NumberFormatBuilder, Options,
    };
    // lexical-core's format::PYTHON_STRING is inaccurate
    const PYTHON_STRING: u128 = NumberFormatBuilder::rebuild(PYTHON3_LITERAL)
        .no_special(false)
        .build();
    f64::from_lexical_with_options::<PYTHON_STRING>(literal, &Options::new()).ok()
}

pub fn is_integer(v: f64) -> bool {
    (v - v.round()).abs() < f64::EPSILON
}

#[derive(Debug)]
pub enum Case {
    Lower,
    Upper,
}

fn format_nan(case: Case) -> String {
    let nan = match case {
        Case::Lower => "nan",
        Case::Upper => "NAN",
    };

    nan.to_string()
}

fn format_inf(case: Case) -> String {
    let inf = match case {
        Case::Lower => "inf",
        Case::Upper => "INF",
    };

    inf.to_string()
}

pub fn format_fixed(precision: usize, magnitude: f64, case: Case) -> String {
    match magnitude {
        magnitude if magnitude.is_finite() => format!("{magnitude:.precision$}"),
        magnitude if magnitude.is_nan() => format_nan(case),
        magnitude if magnitude.is_infinite() => format_inf(case),
        _ => "".to_string(),
    }
}

// Formats floats into Python style exponent notation, by first formatting in Rust style
// exponent notation (`1.0000e0`), then convert to Python style (`1.0000e+00`).
pub fn format_exponent(precision: usize, magnitude: f64, case: Case) -> String {
    match magnitude {
        magnitude if magnitude.is_finite() => {
            let r_exp = format!("{magnitude:.precision$e}");
            let mut parts = r_exp.splitn(2, 'e');
            let base = parts.next().unwrap();
            let exponent = parts.next().unwrap().parse::<i64>().unwrap();
            let e = match case {
                Case::Lower => 'e',
                Case::Upper => 'E',
            };
            format!("{base}{e}{exponent:+#03}")
        }
        magnitude if magnitude.is_nan() => format_nan(case),
        magnitude if magnitude.is_infinite() => format_inf(case),
        _ => "".to_string(),
    }
}

/// If s represents a floating point value, trailing zeros and a possibly trailing
/// decimal point will be removed.
/// This function does NOT work with decimal commas.
fn maybe_remove_trailing_redundant_chars(s: String, alternate_form: bool) -> String {
    if !alternate_form && s.contains('.') {
        // only truncate floating point values when not in alternate form
        let s = remove_trailing_zeros(s);
        remove_trailing_decimal_point(s)
    } else {
        s
    }
}

fn remove_trailing_zeros(s: String) -> String {
    let mut s = s;
    while s.ends_with('0') {
        s.pop();
    }
    s
}

fn remove_trailing_decimal_point(s: String) -> String {
    let mut s = s;
    if s.ends_with('.') {
        s.pop();
    }
    s
}

pub fn format_general(
    precision: usize,
    magnitude: f64,
    case: Case,
    alternate_form: bool,
    always_shows_fract: bool,
) -> String {
    match magnitude {
        magnitude if magnitude.is_finite() => {
            let r_exp = format!("{:.*e}", precision.saturating_sub(1), magnitude);
            let mut parts = r_exp.splitn(2, 'e');
            let base = parts.next().unwrap();
            let exponent = parts.next().unwrap().parse::<i64>().unwrap();
            if exponent < -4 || exponent + (always_shows_fract as i64) >= (precision as i64) {
                let e = match case {
                    Case::Lower => 'e',
                    Case::Upper => 'E',
                };

                let base = maybe_remove_trailing_redundant_chars(
                    format!("{:.*}", precision + 1, base),
                    alternate_form,
                );
                format!("{base}{e}{exponent:+#03}")
            } else {
                let precision = (precision as i64) - 1 - exponent;
                let precision = precision as usize;
                maybe_remove_trailing_redundant_chars(
                    format!("{magnitude:.precision$}"),
                    alternate_form,
                )
            }
        }
        magnitude if magnitude.is_nan() => format_nan(case),
        magnitude if magnitude.is_infinite() => format_inf(case),
        _ => "".to_string(),
    }
}

pub fn to_string(value: f64) -> String {
    let lit = format!("{value:e}");
    if let Some(position) = lit.find('e') {
        let significand = &lit[..position];
        let exponent = &lit[position + 1..];
        let exponent = exponent.parse::<i32>().unwrap();
        if exponent < 16 && exponent > -5 {
            if is_integer(value) {
                format!("{value:.1?}")
            } else {
                value.to_string()
            }
        } else {
            format!("{significand}e{exponent:+#03}")
        }
    } else {
        let mut s = value.to_string();
        s.make_ascii_lowercase();
        s
    }
}

pub fn from_hex(s: &str) -> Option<f64> {
    if let Ok(f) = hexf_parse::parse_hexf64(s, false) {
        return Some(f);
    }
    match s.to_ascii_lowercase().as_str() {
        "nan" | "+nan" | "-nan" => Some(f64::NAN),
        "inf" | "infinity" | "+inf" | "+infinity" => Some(f64::INFINITY),
        "-inf" | "-infinity" => Some(f64::NEG_INFINITY),
        value => {
            let mut hex = String::with_capacity(value.len());
            let has_0x = value.contains("0x");
            let has_p = value.contains('p');
            let has_dot = value.contains('.');
            let mut start = 0;

            if !has_0x && value.starts_with('-') {
                hex.push_str("-0x");
                start += 1;
            } else if !has_0x {
                hex.push_str("0x");
                if value.starts_with('+') {
                    start += 1;
                }
            }

            for (index, ch) in value.chars().enumerate() {
                if ch == 'p' {
                    if has_dot {
                        hex.push('p');
                    } else {
                        hex.push_str(".p");
                    }
                } else if index >= start {
                    hex.push(ch);
                }
            }

            if !has_p && has_dot {
                hex.push_str("p0");
            } else if !has_p && !has_dot {
                hex.push_str(".p0")
            }

            hexf_parse::parse_hexf64(hex.as_str(), false).ok()
        }
    }
}

pub fn to_hex(value: f64) -> String {
    let (mantissa, exponent, sign) = value.integer_decode();
    let sign_fmt = if sign < 0 { "-" } else { "" };
    match value {
        value if value.is_zero() => format!("{sign_fmt}0x0.0p+0"),
        value if value.is_infinite() => format!("{sign_fmt}inf"),
        value if value.is_nan() => "nan".to_owned(),
        _ => {
            const BITS: i16 = 52;
            const FRACT_MASK: u64 = 0xf_ffff_ffff_ffff;
            format!(
                "{}{:#x}.{:013x}p{:+}",
                sign_fmt,
                mantissa >> BITS,
                mantissa & FRACT_MASK,
                exponent + BITS
            )
        }
    }
}

pub fn div(v1: f64, v2: f64) -> Option<f64> {
    if v2 != 0.0 {
        Some(v1 / v2)
    } else {
        None
    }
}

pub fn mod_(v1: f64, v2: f64) -> Option<f64> {
    if v2 != 0.0 {
        let mut val = v1 % v2;
        if (val < 0.0) != (v2 < 0.0) {
            val += v2;
        }
        Some(val)
    } else {
        None
    }
}

pub fn floordiv(v1: f64, v2: f64) -> Option<f64> {
    if v2 != 0.0 {
        Some((v1 / v2).floor())
    } else {
        None
    }
}

pub fn divmod(v1: f64, v2: f64) -> Option<(f64, f64)> {
    if v2 != 0.0 {
        let mut m = v1 % v2;
        let mut d = (v1 - m) / v2;
        if v2.is_sign_negative() != m.is_sign_negative() {
            m += v2;
            d -= 1.0;
        }
        Some((d, m))
    } else {
        None
    }
}

// nextafter algorithm based off of https://gitlab.com/bronsonbdevost/next_afterf
#[allow(clippy::float_cmp)]
pub fn nextafter(x: f64, y: f64) -> f64 {
    if x == y {
        y
    } else if x.is_nan() || y.is_nan() {
        f64::NAN
    } else if x >= f64::INFINITY {
        f64::MAX
    } else if x <= f64::NEG_INFINITY {
        f64::MIN
    } else if x == 0.0 {
        f64::from_bits(1).copysign(y)
    } else {
        // next x after 0 if y is farther from 0 than x, otherwise next towards 0
        // the sign is a separate bit in floats, so bits+1 moves away from 0 no matter the float
        let b = x.to_bits();
        let bits = if (y > x) == (x > 0.0) { b + 1 } else { b - 1 };
        let ret = f64::from_bits(bits);
        if ret == 0.0 {
            ret.copysign(x)
        } else {
            ret
        }
    }
}

pub fn ulp(x: f64) -> f64 {
    if x.is_nan() {
        return x;
    }
    let x = x.abs();
    let x2 = nextafter(x, f64::INFINITY);
    if x2.is_infinite() {
        // special case: x is the largest positive representable float
        let x2 = nextafter(x, f64::NEG_INFINITY);
        x - x2
    } else {
        x2 - x
    }
}

pub fn round_float_digits(x: f64, ndigits: i32) -> Option<f64> {
    let float = if ndigits.is_zero() {
        let fract = x.fract();
        if (fract.abs() - 0.5).abs() < f64::EPSILON {
            if x.trunc() % 2.0 == 0.0 {
                x - fract
            } else {
                x + fract
            }
        } else {
            x.round()
        }
    } else {
        const NDIGITS_MAX: i32 =
            ((f64::MANTISSA_DIGITS as i32 - f64::MIN_EXP) as f64 * f64::consts::LOG10_2) as i32;
        const NDIGITS_MIN: i32 = -(((f64::MAX_EXP + 1) as f64 * f64::consts::LOG10_2) as i32);
        if ndigits > NDIGITS_MAX {
            x
        } else if ndigits < NDIGITS_MIN {
            0.0f64.copysign(x)
        } else {
            let (y, pow1, pow2) = if ndigits >= 0 {
                // according to cpython: pow1 and pow2 are each safe from overflow, but
                //                       pow1*pow2 ~= pow(10.0, ndigits) might overflow
                let (pow1, pow2) = if ndigits > 22 {
                    (10.0.powf((ndigits - 22) as f64), 1e22)
                } else {
                    (10.0.powf(ndigits as f64), 1.0)
                };
                let y = (x * pow1) * pow2;
                if !y.is_finite() {
                    return Some(x);
                }
                (y, pow1, Some(pow2))
            } else {
                let pow1 = 10.0.powf((-ndigits) as f64);
                (x / pow1, pow1, None)
            };
            let z = y.round();
            #[allow(clippy::float_cmp)]
            let z = if (y - z).abs() == 0.5 {
                2.0 * (y / 2.0).round()
            } else {
                z
            };
            let z = if let Some(pow2) = pow2 {
                // ndigits >= 0
                (z / pow2) / pow1
            } else {
                z * pow1
            };

            if !z.is_finite() {
                // overflow
                return None;
            }

            z
        }
    };
    Some(float)
}

#[test]
fn test_to_hex() {
    use rand::Rng;
    for _ in 0..20000 {
        let bytes = rand::thread_rng().gen::<[u64; 1]>();
        let f = f64::from_bits(bytes[0]);
        if !f.is_finite() {
            continue;
        }
        let hex = to_hex(f);
        // println!("{} -> {}", f, hex);
        let roundtrip = hexf_parse::parse_hexf64(&hex, false).unwrap();
        // println!("  -> {}", roundtrip);
        assert!(f == roundtrip, "{} {} {}", f, hex, roundtrip);
    }
}

#[test]
fn test_remove_trailing_zeros() {
    assert!(remove_trailing_zeros(String::from("100")) == *"1");
    assert!(remove_trailing_zeros(String::from("100.00")) == *"100.");

    // leave leading zeros untouched
    assert!(remove_trailing_zeros(String::from("001")) == *"001");

    // leave strings untouched if they don't end with 0
    assert!(remove_trailing_zeros(String::from("101")) == *"101");
}

#[test]
fn test_remove_trailing_decimal_point() {
    assert!(remove_trailing_decimal_point(String::from("100.")) == *"100");
    assert!(remove_trailing_decimal_point(String::from("1.")) == *"1");

    // leave leading decimal points untouched
    assert!(remove_trailing_decimal_point(String::from(".5")) == *".5");
}

#[test]
fn test_maybe_remove_trailing_redundant_chars() {
    assert!(maybe_remove_trailing_redundant_chars(String::from("100."), true) == *"100.");
    assert!(maybe_remove_trailing_redundant_chars(String::from("100."), false) == *"100");
    assert!(maybe_remove_trailing_redundant_chars(String::from("1."), false) == *"1");
    assert!(maybe_remove_trailing_redundant_chars(String::from("10.0"), false) == *"10");

    // don't truncate integers
    assert!(maybe_remove_trailing_redundant_chars(String::from("1000"), false) == *"1000");
}