okane_core/syntax/pretty_decimal.rs

1//! Module to define [PrettyDecimal], [Decimal] with formatting.
2
3use std::{convert::TryInto, fmt::Display, ops::Neg, str::FromStr};
4
5use bounded_static::{IntoBoundedStatic, ToBoundedStatic, ToStatic};
6use rust_decimal::Decimal;
7
8/// Decimal formatting type for pretty-printing.
9#[derive(Debug, PartialEq, Eq, Clone, Copy, ToStatic)]
10#[non_exhaustive]
11pub enum Format {
12    /// Decimal without no formatting, such as
13    /// `1234` or `1234.5`.
14    Plain,
15    /// Use `,` on every thousands, `.` for the decimal point.
16    Comma3Dot,
17}
18
19/// Decimal with the original format information encoded.
20#[derive(Debug, Default, PartialEq, Eq, Clone)]
21#[non_exhaustive] // Don't want to construct directly.
22pub struct PrettyDecimal {
23    /// Format of the decimal, None means there's no associated information.
24    pub format: Option<Format>,
25    pub value: Decimal,
26}
27
28impl Neg for PrettyDecimal {
29    type Output = Self;
30
31    fn neg(mut self) -> Self::Output {
32        self.set_sign_positive(!self.value.is_sign_positive());
33        self
34    }
35}
36
37impl ToBoundedStatic for PrettyDecimal {
38    type Static = Self;
39
40    fn to_static(&self) -> <Self as ToBoundedStatic>::Static {
41        self.clone()
42    }
43}
44
45impl IntoBoundedStatic for PrettyDecimal {
46    type Static = Self;
47
48    fn into_static(self) -> <Self as IntoBoundedStatic>::Static {
49        self
50    }
51}
52
53#[derive(thiserror::Error, PartialEq, Debug)]
54pub enum Error {
55    #[error("unexpected char {0} at {1}")]
56    UnexpectedChar(String, usize),
57    #[error("comma required at {0}")]
58    CommaRequired(usize),
59    #[error("unexpressible decimal {0}")]
60    InvalidDecimal(#[from] rust_decimal::Error),
61}
62
63impl PrettyDecimal {
64    /// Constructs unformatted PrettyDecimal.
65    #[inline]
66    pub const fn unformatted(value: Decimal) -> Self {
67        Self {
68            format: None,
69            value,
70        }
71    }
72
73    /// Constructs plain PrettyDecimal.
74    #[inline]
75    pub const fn plain(value: Decimal) -> Self {
76        Self {
77            format: Some(Format::Plain),
78            value,
79        }
80    }
81
82    /// Constructs comma3 PrettyDecimal.
83    #[inline]
84    pub const fn comma3dot(value: Decimal) -> Self {
85        Self {
86            format: Some(Format::Comma3Dot),
87            value,
88        }
89    }
90
91    /// Returns the current scale.
92    pub const fn scale(&self) -> u32 {
93        self.value.scale()
94    }
95
96    /// Rescale the underlying value.
97    pub fn rescale(&mut self, scale: u32) {
98        self.value.rescale(scale)
99    }
100
101    /// Sets the sign positive.
102    pub fn set_sign_positive(&mut self, positive: bool) {
103        self.value.set_sign_positive(positive)
104    }
105
106    /// Returns `true` if the value is positive.
107    pub const fn is_sign_positive(&self) -> bool {
108        self.value.is_sign_positive()
109    }
110}
111
112impl From<PrettyDecimal> for Decimal {
113    #[inline]
114    fn from(value: PrettyDecimal) -> Self {
115        value.value
116    }
117}
118
119impl FromStr for PrettyDecimal {
120    type Err = Error;
121
122    fn from_str(s: &str) -> Result<Self, Self::Err> {
123        // Only ASCII chars supported, use bytes.
124        let mut comma_pos = None;
125        let mut format = None;
126        let mut mantissa: i128 = 0;
127        let mut scale: Option<u32> = None;
128        let mut prefix_len = 0;
129        let mut sign = 1;
130        let aligned_comma = |offset, cp, pos| match (cp, pos) {
131            (None, _) if pos > offset && pos <= 3 + offset => true,
132            _ if cp == Some(pos) => true,
133            _ => false,
134        };
135        for (i, c) in s.bytes().enumerate() {
136            match (comma_pos, i, c) {
137                (_, 0, b'-') => {
138                    prefix_len = 1;
139                    sign = -1;
140                }
141                (_, _, b',') if aligned_comma(prefix_len, comma_pos, i) => {
142                    format = Some(Format::Comma3Dot);
143                    comma_pos = Some(i + 4);
144                }
145                (_, _, b'.') if comma_pos.is_none() || comma_pos == Some(i) => {
146                    scale = Some(0);
147                    comma_pos = None;
148                }
149                (Some(cp), _, _) if cp == i => {
150                    return Err(Error::CommaRequired(i));
151                }
152                _ if c.is_ascii_digit() => {
153                    if scale.is_none() && format.is_none() && i >= 3 + prefix_len {
154                        format = Some(Format::Plain);
155                    }
156                    mantissa = mantissa * 10 + (c as u32 - '0' as u32) as i128;
157                    scale = scale.map(|x| x + 1);
158                }
159                _ => {
160                    return Err(Error::UnexpectedChar(try_find_char(s, i, c), i));
161                }
162            }
163        }
164        let value = Decimal::try_from_i128_with_scale(sign * mantissa, scale.unwrap_or(0))?;
165        Ok(Self { format, value })
166    }
167}
168
/// Returns the (possibly multi-byte) character of `s` that contains byte index `i`.
///
/// Returns an empty string when `i` is at or past the end of `s`. It returns a `String`
/// rather than a `char` to leave room for the debug rendering of the raw byte `chr` as a
/// fallback.
fn try_find_char(s: &str, i: usize, chr: u8) -> String {
    if i >= s.len() {
        return String::new();
    }
    // The last character starting at or before `i` is the one covering byte `i`.
    s.char_indices()
        .take_while(|&(start, _)| start <= i)
        .last()
        .map(|(_, ch)| ch.to_string())
        .unwrap_or_else(|| format!("{:?}", chr))
}
179
180impl Display for PrettyDecimal {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        match self.format {
183            Some(Format::Plain) | None => self.value.fmt(f),
184            Some(Format::Comma3Dot) => {
185                if self.value.is_sign_negative() {
186                    write!(f, "-")?;
187                }
188                let mantissa = self.value.abs().mantissa().to_string();
189                let scale: usize = self
190                    .value
191                    .scale()
192                    .try_into()
193                    .expect("32-bit or larger bit only");
194                let mut remainder = mantissa.as_str();
195                // Here we assume mantissa is all ASCII (given it's [0-9.]+)
196                let mut initial_integer = true;
197                // caluclate the first comma position out of the integral portion digits.
198                let mut comma_pos = (mantissa.len() - scale) % 3;
199                if comma_pos == 0 {
200                    comma_pos = 3;
201                }
202                while remainder.len() > scale {
203                    if !initial_integer {
204                        write!(f, ",")?;
205                    }
206                    let section;
207                    (section, remainder) = remainder.split_at(comma_pos);
208                    write!(f, "{}", section)?;
209                    comma_pos = 3;
210                    initial_integer = false;
211                }
212                if initial_integer {
213                    write!(f, "0")?;
214                }
215                if !remainder.is_empty() {
216                    write!(f, ".{}", remainder)?;
217                }
218                Ok(())
219            }
220        }
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use rust_decimal_macros::dec;

    /// Asserts, in order, that each input string parses into the paired `PrettyDecimal`.
    fn assert_parses(cases: &[(&str, PrettyDecimal)]) {
        for (input, expected) in cases {
            assert_eq!(Ok(expected.clone()), input.parse(), "input: {:?}", input);
        }
    }

    #[test]
    fn from_str_unformatted() {
        // Below 1000 and without a comma, plain and comma3dot notation look identical,
        // so such inputs are unformatted rather than plain.
        assert_parses(&[
            ("1", PrettyDecimal::unformatted(dec!(1))),
            ("-1", PrettyDecimal::unformatted(dec!(-1))),
            ("12", PrettyDecimal::unformatted(dec!(12))),
            ("-12", PrettyDecimal::unformatted(dec!(-12))),
            ("123", PrettyDecimal::unformatted(dec!(123))),
            ("-123", PrettyDecimal::unformatted(dec!(-123))),
            ("0.123450", PrettyDecimal::unformatted(dec!(0.123450))),
        ]);
    }

    #[test]
    fn from_str_plain() {
        assert_parses(&[
            ("1234", PrettyDecimal::plain(dec!(1234))),
            ("-1234", PrettyDecimal::plain(dec!(-1234))),
            ("1234567", PrettyDecimal::plain(dec!(1234567))),
            ("-1234567", PrettyDecimal::plain(dec!(-1234567))),
            ("1234.567", PrettyDecimal::plain(dec!(1234.567))),
            ("-1234.567", PrettyDecimal::plain(dec!(-1234.567))),
        ]);
    }

    #[test]
    fn from_str_comma() {
        assert_parses(&[
            ("1,234", PrettyDecimal::comma3dot(dec!(1234))),
            ("-1,234", PrettyDecimal::comma3dot(dec!(-1234))),
            ("12,345", PrettyDecimal::comma3dot(dec!(12345))),
            ("-12,345", PrettyDecimal::comma3dot(dec!(-12345))),
            ("123,456", PrettyDecimal::comma3dot(dec!(123456))),
            ("-123,456", PrettyDecimal::comma3dot(dec!(-123456))),
            ("1,234,567", PrettyDecimal::comma3dot(dec!(1234567))),
            ("-1,234,567", PrettyDecimal::comma3dot(dec!(-1234567))),
            ("1,234.567", PrettyDecimal::comma3dot(dec!(1234.567))),
            ("-1,234.567", PrettyDecimal::comma3dot(dec!(-1234.567))),
        ]);
    }

    #[test]
    fn display_plain() {
        assert_eq!("1.234000", PrettyDecimal::plain(dec!(1.234000)).to_string());
    }

    #[test]
    fn display_comma3_dot() {
        let cases = [
            (dec!(123), "123"),
            (dec!(-1234), "-1,234"),
            (dec!(0), "0"),
            (dec!(0.1200), "0.1200"),
            (dec!(1.234000), "1.234000"),
            (dec!(123.4), "123.4"),
            (dec!(1234567.890120), "1,234,567.890120"),
        ];
        for (value, expected) in cases.iter().copied() {
            assert_eq!(expected, PrettyDecimal::comma3dot(value).to_string());
        }
    }

    #[test]
    fn scale_returns_correct_number() {
        for (value, scale) in [(dec!(1230), 0), (dec!(1230.4), 1)].iter().copied() {
            assert_eq!(scale, PrettyDecimal::comma3dot(value).scale());
        }
    }
}