Skip to main content

rustpython_literal/
float.rs

1use crate::format::Case;
2use alloc::borrow::ToOwned;
3use alloc::format;
4use alloc::string::{String, ToString};
5use core::f64;
6use num_traits::{Float, Zero};
7
8pub fn parse_str(literal: &str) -> Option<f64> {
9    parse_inner(literal.trim().as_bytes())
10}
11
12pub fn parse_bytes(literal: &[u8]) -> Option<f64> {
13    parse_inner(literal.trim_ascii())
14}
15
16fn parse_inner(literal: &[u8]) -> Option<f64> {
17    use lexical_parse_float::{
18        FromLexicalWithOptions, NumberFormatBuilder, Options, format::PYTHON3_LITERAL,
19    };
20
21    // lexical-core's format::PYTHON_STRING is inaccurate
22    const PYTHON_STRING: u128 = NumberFormatBuilder::rebuild(PYTHON3_LITERAL)
23        .no_special(false)
24        .build_unchecked();
25    f64::from_lexical_with_options::<PYTHON_STRING>(literal, &Options::new()).ok()
26}
27
/// Returns `true` when `v` is finite and has no fractional part.
///
/// NaN and the infinities are never considered integers.
pub fn is_integer(v: f64) -> bool {
    if !v.is_finite() {
        return false;
    }
    v.fract() == 0.0
}
31
32fn format_nan(case: Case) -> String {
33    let nan = match case {
34        Case::Lower => "nan",
35        Case::Upper => "NAN",
36    };
37
38    nan.to_string()
39}
40
41fn format_inf(case: Case) -> String {
42    let inf = match case {
43        Case::Lower => "inf",
44        Case::Upper => "INF",
45    };
46
47    inf.to_string()
48}
49
/// Returns `"."` when the alternate form is requested with zero precision (so that a
/// decimal point is still shown), and `""` in every other case.
pub const fn decimal_point_or_empty(precision: usize, alternate_form: bool) -> &'static str {
    if alternate_form && precision == 0 {
        "."
    } else {
        ""
    }
}
56
57pub fn format_fixed(precision: usize, magnitude: f64, case: Case, alternate_form: bool) -> String {
58    match magnitude {
59        magnitude if magnitude.is_finite() => {
60            let point = decimal_point_or_empty(precision, alternate_form);
61            let precision = core::cmp::min(precision, u16::MAX as usize);
62            format!("{magnitude:.precision$}{point}")
63        }
64        magnitude if magnitude.is_nan() => format_nan(case),
65        magnitude if magnitude.is_infinite() => format_inf(case),
66        _ => "".to_string(),
67    }
68}
69
70// Formats floats into Python style exponent notation, by first formatting in Rust style
71// exponent notation (`1.0000e0`), then convert to Python style (`1.0000e+00`).
72pub fn format_exponent(
73    precision: usize,
74    magnitude: f64,
75    case: Case,
76    alternate_form: bool,
77) -> String {
78    match magnitude {
79        magnitude if magnitude.is_finite() => {
80            let r_exp = format!("{magnitude:.precision$e}");
81            let mut parts = r_exp.splitn(2, 'e');
82            let base = parts.next().unwrap();
83            let exponent = parts.next().unwrap().parse::<i64>().unwrap();
84            let e = match case {
85                Case::Lower => 'e',
86                Case::Upper => 'E',
87            };
88            let point = decimal_point_or_empty(precision, alternate_form);
89            format!("{base}{point}{e}{exponent:+#03}")
90        }
91        magnitude if magnitude.is_nan() => format_nan(case),
92        magnitude if magnitude.is_infinite() => format_inf(case),
93        _ => "".to_string(),
94    }
95}
96
/// If `s` represents a floating point value (contains a `.`), strips its trailing zeros and
/// then at most one trailing decimal point. Nothing is stripped in alternate form.
/// This function does NOT work with decimal commas.
fn maybe_remove_trailing_redundant_chars(s: String, alternate_form: bool) -> String {
    // Alternate form keeps trailing zeros; strings without '.' are treated as integers.
    if alternate_form || !s.contains('.') {
        return s;
    }
    let without_zeros = s.trim_end_matches('0');
    let trimmed = without_zeros.strip_suffix('.').unwrap_or(without_zeros);
    String::from(trimmed)
}
109
/// Removes every trailing `'0'` from `s`, reusing its allocation.
fn remove_trailing_zeros(mut s: String) -> String {
    let kept = s.trim_end_matches('0').len();
    s.truncate(kept);
    s
}
117
/// Removes a single trailing `'.'` from `s`, if present, reusing its allocation.
fn remove_trailing_decimal_point(mut s: String) -> String {
    if let Some(kept) = s.strip_suffix('.').map(str::len) {
        s.truncate(kept);
    }
    s
}
125
126pub fn format_general(
127    precision: usize,
128    magnitude: f64,
129    case: Case,
130    alternate_form: bool,
131    always_shows_fract: bool,
132) -> String {
133    match magnitude {
134        magnitude if magnitude.is_finite() => {
135            let r_exp = format!("{:.*e}", precision.saturating_sub(1), magnitude);
136            let mut parts = r_exp.splitn(2, 'e');
137            let base = parts.next().unwrap();
138            let exponent = parts.next().unwrap().parse::<i64>().unwrap();
139            if exponent < -4 || exponent + (always_shows_fract as i64) >= (precision as i64) {
140                let e = match case {
141                    Case::Lower => 'e',
142                    Case::Upper => 'E',
143                };
144                let magnitude = format!("{:.*}", precision + 1, base);
145                let base = maybe_remove_trailing_redundant_chars(magnitude, alternate_form);
146                let point = decimal_point_or_empty(precision.saturating_sub(1), alternate_form);
147                format!("{base}{point}{e}{exponent:+#03}")
148            } else {
149                let precision = ((precision as i64) - 1 - exponent) as usize;
150                let magnitude = format!("{magnitude:.precision$}");
151                let base = maybe_remove_trailing_redundant_chars(magnitude, alternate_form);
152                let point = decimal_point_or_empty(precision, alternate_form);
153                format!("{base}{point}")
154            }
155        }
156        magnitude if magnitude.is_nan() => format_nan(case),
157        magnitude if magnitude.is_infinite() => format_inf(case),
158        _ => "".to_string(),
159    }
160}
161
// TODO: rewrite using format_general
/// Renders `value` in a `repr`-like form.
///
/// When the decimal exponent lies in `-4..=15`, positional notation is used, and integral
/// values keep a `.0` suffix. Other finite values use scientific notation with a signed,
/// at least two-digit exponent (`1e+16`, `1.5e-07`). Non-finite values become `nan`,
/// `inf` or `-inf`.
pub fn to_string(value: f64) -> String {
    let scientific = format!("{value:e}");
    let Some((significand, exponent)) = scientific.split_once('e') else {
        // Rust spells non-finite values without an 'e' ("NaN", "inf", "-inf").
        return value.to_string().to_ascii_lowercase();
    };
    let exponent = exponent.parse::<i32>().unwrap();
    if !(-4..16).contains(&exponent) {
        return format!("{significand}e{exponent:+#03}");
    }
    if value.is_finite() && value.fract() == 0.0 {
        format!("{value:.1?}")
    } else {
        value.to_string()
    }
}
184
185pub fn from_hex(s: &str) -> Option<f64> {
186    if let Ok(f) = hexf_parse::parse_hexf64(s, false) {
187        return Some(f);
188    }
189    match s.to_ascii_lowercase().as_str() {
190        "nan" | "+nan" | "-nan" => Some(f64::NAN),
191        "inf" | "infinity" | "+inf" | "+infinity" => Some(f64::INFINITY),
192        "-inf" | "-infinity" => Some(f64::NEG_INFINITY),
193        value => {
194            let mut hex = String::with_capacity(value.len());
195            let has_0x = value.contains("0x");
196            let has_p = value.contains('p');
197            let has_dot = value.contains('.');
198            let mut start = 0;
199
200            if !has_0x && value.starts_with('-') {
201                hex.push_str("-0x");
202                start += 1;
203            } else if !has_0x {
204                hex.push_str("0x");
205                if value.starts_with('+') {
206                    start += 1;
207                }
208            }
209
210            for (index, ch) in value.chars().enumerate() {
211                if ch == 'p' {
212                    if has_dot {
213                        hex.push('p');
214                    } else {
215                        hex.push_str(".p");
216                    }
217                } else if index >= start {
218                    hex.push(ch);
219                }
220            }
221
222            if !has_p && has_dot {
223                hex.push_str("p0");
224            } else if !has_p && !has_dot {
225                hex.push_str(".p0")
226            }
227
228            hexf_parse::parse_hexf64(hex.as_str(), false).ok()
229        }
230    }
231}
232
233pub fn to_hex(value: f64) -> String {
234    let (mantissa, exponent, sign) = value.integer_decode();
235    let sign_fmt = if sign < 0 { "-" } else { "" };
236    match value {
237        value if value.is_zero() => format!("{sign_fmt}0x0.0p+0"),
238        value if value.is_infinite() => format!("{sign_fmt}inf"),
239        value if value.is_nan() => "nan".to_owned(),
240        _ => {
241            const BITS: i16 = 52;
242            const FRACT_MASK: u64 = 0xf_ffff_ffff_ffff;
243            format!(
244                "{}{:#x}.{:013x}p{:+}",
245                sign_fmt,
246                mantissa >> BITS,
247                mantissa & FRACT_MASK,
248                exponent + BITS
249            )
250        }
251    }
252}
253
/// Round-trips random finite `f64` bit patterns through `to_hex` and `hexf_parse`.
#[test]
fn test_to_hex() {
    use rand::Rng;
    let finite_samples = (0..20000)
        .map(|_| f64::from_bits(rand::rng().random::<u64>()))
        .filter(|candidate| candidate.is_finite());
    for f in finite_samples {
        let hex = to_hex(f);
        let roundtrip = hexf_parse::parse_hexf64(&hex, false).unwrap();
        assert!(f == roundtrip, "{f} {hex} {roundtrip}");
    }
}
270
#[test]
fn test_remove_trailing_zeros() {
    let cases = [
        ("100", "1"),
        ("100.00", "100."),
        // leading zeros are left untouched
        ("001", "001"),
        // strings that do not end with '0' are left untouched
        ("101", "101"),
    ];
    for &(input, expected) in &cases {
        assert_eq!(remove_trailing_zeros(String::from(input)), expected);
    }
}
282
#[test]
fn test_remove_trailing_decimal_point() {
    let cases = [
        ("100.", "100"),
        ("1.", "1"),
        // a leading decimal point is left untouched
        (".5", ".5"),
    ];
    for &(input, expected) in &cases {
        assert_eq!(remove_trailing_decimal_point(String::from(input)), expected);
    }
}
291
#[test]
fn test_maybe_remove_trailing_redundant_chars() {
    let cases = [
        ("100.", true, "100."),
        ("100.", false, "100"),
        ("1.", false, "1"),
        ("10.0", false, "10"),
        // integers (no '.') are never truncated
        ("1000", false, "1000"),
    ];
    for &(input, alternate_form, expected) in &cases {
        assert_eq!(
            maybe_remove_trailing_redundant_chars(String::from(input), alternate_form),
            expected
        );
    }
}