beancount_parser/
amount.rs

1use std::{
2    borrow::Borrow,
3    fmt::{Debug, Display, Formatter},
4    ops::{Add, Div, Mul, Neg, Sub},
5    str::FromStr,
6    sync::Arc,
7};
8
9use nom::{
10    branch::alt,
11    bytes::complete::{take_while, take_while1},
12    character::complete::{char, one_of, satisfy, space0, space1},
13    combinator::{all_consuming, iterator, map_res, opt, recognize, verify},
14    sequence::{delimited, preceded, terminated, tuple},
15    Finish,
16};
17
18use crate::{IResult, Span};
19
20/// Price directive
21///
22/// # Example
23///
24/// ```
25/// use beancount_parser::{BeancountFile, DirectiveContent};
26/// let input = "2023-05-27 price CHF  4 PLN";
27/// let beancount: BeancountFile<f64> = input.parse().unwrap();
28/// let DirectiveContent::Price(price) = &beancount.directives[0].content else { unreachable!() };
29/// assert_eq!(price.currency.as_str(), "CHF");
30/// assert_eq!(price.amount.value, 4.0);
31/// assert_eq!(price.amount.currency.as_str(), "PLN");
32/// ```
33#[derive(Debug, Clone, PartialEq)]
34pub struct Price<D> {
35    /// Currency
36    pub currency: Currency,
37    /// Price of the currency
38    pub amount: Amount<D>,
39}
40
41/// Amount
42///
43/// Where `D` is the decimal type (like `f64` or `rust_decimal::Decimal`)
44///
45/// For an example, look at the [`Price`] directive
46#[derive(Debug, Clone, PartialEq)]
47pub struct Amount<D> {
48    /// The value (decimal) part
49    pub value: D,
50    /// Currency
51    pub currency: Currency,
52}
53
54/// Currency
55///
56/// One may use [`Currency::as_str`] to get the string representation of the currency
57///
58/// For an example, look at the [`Price`] directive
59#[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq, Hash)]
60pub struct Currency(Arc<str>);
61
62impl Currency {
63    /// Returns underlying string representation
64    #[must_use]
65    pub fn as_str(&self) -> &str {
66        &self.0
67    }
68}
69
70impl Display for Currency {
71    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
72        Display::fmt(&self.0, f)
73    }
74}
75
76impl AsRef<str> for Currency {
77    fn as_ref(&self) -> &str {
78        self.0.as_ref()
79    }
80}
81
82impl Borrow<str> for Currency {
83    fn borrow(&self) -> &str {
84        self.0.borrow()
85    }
86}
87
88impl<'a> TryFrom<&'a str> for Currency {
89    type Error = crate::ConversionError;
90    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
91        value.parse().map_err(|_| crate::ConversionError)
92    }
93}
94
95impl FromStr for Currency {
96    type Err = crate::Error;
97    fn from_str(s: &str) -> Result<Self, Self::Err> {
98        let span = Span::new(s);
99        match all_consuming(currency)(span).finish() {
100            Ok((_, currency)) => Ok(currency),
101            Err(_) => Err(crate::Error::new(s, span)),
102        }
103    }
104}
105
106pub(crate) fn parse<D: Decimal>(input: Span<'_>) -> IResult<'_, Amount<D>> {
107    let (input, value) = expression(input)?;
108    let (input, _) = space1(input)?;
109    let (input, currency) = currency(input)?;
110    Ok((input, Amount { value, currency }))
111}
112
113pub(crate) fn expression<D: Decimal>(input: Span<'_>) -> IResult<'_, D> {
114    alt((negation, sum))(input)
115}
116
117fn sum<D: Decimal>(input: Span<'_>) -> IResult<'_, D> {
118    let (input, value) = product(input)?;
119    let mut iter = iterator(
120        input,
121        tuple((delimited(space0, one_of("+-"), space0), product)),
122    );
123    let value = iter.fold(value, |a, (op, b)| match op {
124        '+' => a + b,
125        '-' => a - b,
126        op => unreachable!("unsupported operator: {}", op),
127    });
128    let (input, ()) = iter.finish()?;
129    Ok((input, value))
130}
131
132fn product<D: Decimal>(input: Span<'_>) -> IResult<'_, D> {
133    let (input, value) = atom(input)?;
134    let mut iter = iterator(
135        input,
136        tuple((delimited(space0, one_of("*/"), space0), atom)),
137    );
138    let value = iter.fold(value, |a, (op, b)| match op {
139        '*' => a * b,
140        '/' => a / b,
141        op => unreachable!("unsupported operator: {}", op),
142    });
143    let (input, ()) = iter.finish()?;
144    Ok((input, value))
145}
146
147fn atom<D: Decimal>(input: Span<'_>) -> IResult<'_, D> {
148    alt((literal, group))(input)
149}
150
151fn group<D: Decimal>(input: Span<'_>) -> IResult<'_, D> {
152    delimited(
153        terminated(char('('), space0),
154        expression,
155        preceded(space0, char(')')),
156    )(input)
157}
158
159fn negation<D: Decimal>(input: Span<'_>) -> IResult<'_, D> {
160    let (input, _) = char('-')(input)?;
161    let (input, _) = space0(input)?;
162    let (input, expr) = group::<D>(input)?;
163    Ok((input, -expr))
164}
165
166fn literal<D: Decimal>(input: Span<'_>) -> IResult<'_, D> {
167    map_res(
168        recognize(tuple((
169            opt(char('-')),
170            space0,
171            take_while1(|c: char| c.is_numeric() || c == '.' || c == ','),
172        ))),
173        |s: Span<'_>| s.fragment().replace([',', ' '], "").parse(),
174    )(input)
175}
176
177pub(crate) fn price<D: Decimal>(input: Span<'_>) -> IResult<'_, Price<D>> {
178    let (input, currency) = currency(input)?;
179    let (input, _) = space1(input)?;
180    let (input, amount) = parse(input)?;
181    Ok((input, Price { currency, amount }))
182}
183
184pub(crate) fn currency(input: Span<'_>) -> IResult<'_, Currency> {
185    let (input, currency) = recognize(tuple((
186        satisfy(char::is_uppercase),
187        verify(
188            take_while(|c: char| {
189                c.is_uppercase() || c.is_numeric() || c == '-' || c == '_' || c == '.' || c == '\''
190            }),
191            |s: &Span<'_>| {
192                s.fragment()
193                    .chars()
194                    .last()
195                    .map_or(true, |c| c.is_uppercase() || c.is_numeric())
196            },
197        ),
198    )))(input)?;
199    Ok((input, Currency(Arc::from(*currency.fragment()))))
200}
201
202/// Decimal type to which amount values and expressions will be parsed into.
203///
204/// # Notable implementations
205///
206/// * `f64`
207/// * `Decimal` of the crate [rust_decimal]
208///
209/// [rust_decimal]: https://docs.rs/rust_decimal
210///
211pub trait Decimal:
212    FromStr
213    + Default
214    + Clone
215    + Debug
216    + Add<Output = Self>
217    + Sub<Output = Self>
218    + Mul<Output = Self>
219    + Div<Output = Self>
220    + Neg<Output = Self>
221    + PartialEq
222    + PartialOrd
223{
224}
225
226impl<D> Decimal for D where
227    D: FromStr
228        + Default
229        + Clone
230        + Debug
231        + Add<Output = Self>
232        + Sub<Output = Self>
233        + Mul<Output = Self>
234        + Div<Output = Self>
235        + Neg<Output = Self>
236        + PartialEq
237        + PartialOrd
238{
239}
240
241#[cfg(test)]
242mod tests {
243    use super::*;
244    use rstest::rstest;
245
246    #[rstest]
247    #[case("CHF")]
248    fn currency_from_str_should_parse_valid_currency(#[case] input: &str) {
249        let currency: Currency = input.parse().unwrap();
250        assert_eq!(currency.as_str(), input);
251    }
252
253    #[rstest]
254    #[case("")]
255    #[case(" ")]
256    #[case("oops")]
257    fn currency_from_str_should_not_parse_invalid_currency(#[case] input: &str) {
258        let currency: Result<Currency, _> = input.parse();
259        assert!(currency.is_err(), "{currency:?}");
260    }
261}
262
263#[cfg(test)]
264pub(crate) mod chumsky {
265    use super::{Amount, Currency};
266    use crate::{ChumskyError, ChumskyParser, Decimal};
267
268    use chumsky::{prelude::*, text::whitespace};
269
270    pub(crate) fn amount<D: Decimal + 'static>() -> impl ChumskyParser<Amount<D>> {
271        expression::<D>()
272            .then_ignore(whitespace())
273            .then(currency())
274            .map(|(value, currency)| Amount { value, currency })
275    }
276
277    pub(crate) fn currency() -> impl ChumskyParser<Currency> {
278        filter(|c: &char| c.is_ascii_uppercase())
279            .chain(
280                filter(|c: &char| c.is_ascii_uppercase() || c.is_ascii_digit())
281                    .or(one_of("'.-_"))
282                    .repeated(),
283            )
284            .collect::<String>()
285            .map(|s| Currency(s.into()))
286    }
287
288    pub(crate) fn expression<D: Decimal + 'static>() -> impl ChumskyParser<D> {
289        recursive(|expr| {
290            let atom = atom(expr);
291            let product = product(atom);
292            sum(product)
293        })
294        .labelled("expression")
295    }
296
297    fn sum<D: Decimal>(atom: impl ChumskyParser<D> + Clone) -> impl ChumskyParser<D> + Clone {
298        atom.clone()
299            .then(just('+').or(just('-')).padded().then(atom).repeated())
300            .foldl(|sum, (op, value)| match op {
301                '+' => sum + value,
302                '-' => sum - value,
303                op => unreachable!("unknown sum operator: {op}"),
304            })
305    }
306
307    fn product<D: Decimal>(atom: impl ChumskyParser<D> + Clone) -> impl ChumskyParser<D> + Clone {
308        atom.clone()
309            .then(just('*').or(just('/')).padded().then(atom).repeated())
310            .foldl(|product, (op, value)| match op {
311                '*' => product * value,
312                '/' => product / value,
313                op => unreachable!("unknown product operator: {op}"),
314            })
315    }
316
317    fn atom<D: Decimal>(expr: impl ChumskyParser<D> + Clone) -> impl ChumskyParser<D> + Clone {
318        just('-')
319            .or_not()
320            .then(
321                expr.padded()
322                    .delimited_by(just('('), just(')'))
323                    .or(literal::<D>()),
324            )
325            .map(|(neg_op, b)| if neg_op == Some('-') { -b } else { b })
326    }
327
328    fn literal<D: Decimal>() -> impl ChumskyParser<D> + Copy {
329        let digit = filter(|c: &char| c.is_ascii_digit());
330        let int_part = digit.repeated().at_least(1).chain::<char, _, _>(
331            just(',')
332                .chain(digit.repeated().at_least(1))
333                .repeated()
334                .flatten(),
335        );
336        let fract_part = just('.').chain::<char, _, _>(digit.repeated()).or_not();
337        int_part
338            .chain::<char, _, _>(fract_part)
339            .collect::<String>()
340            .try_map(|string: String, span| {
341                string
342                    .replace(',', "")
343                    .parse()
344                    .map_err(|_| ChumskyError::custom(span, "not a number"))
345            })
346            .labelled("numeric value")
347    }
348
349    #[cfg(test)]
350    mod tests {
351        use super::*;
352        use rstest::rstest;
353
354        #[rstest]
355        #[case::literal("42", 42)]
356        #[case::neg_literal("-42", -42)]
357        #[case::double_neg("-42", -42)]
358        #[case::neg_parenthesis("-(42)", -42)]
359        #[case::neg_parenthesis_2("-(-42)", 42)]
360        #[case::addition("1+2+3", 6)]
361        #[case::addition_with_space("1 + 2", 3)]
362        #[case::substraction("10-2-5", 3)]
363        #[case::substraction_with_space("5 - 3", 2)]
364        #[case::multiplication("2*3*4", 24)]
365        #[case::multiplication_with_space("2 * 3", 6)]
366        #[case::division("6/3", 2)]
367        #[case::division_with_space("6 / 3", 2)]
368        #[case::operator_priority("1 + 2 * 3", 7)]
369        #[case::parenthesis("(1+2)*3", 9)]
370        #[case::parenthesis_nested("( 1 + ( 2 + 2 ) * 4 ) * 3", 51)]
371        fn should_parse_valid_expression(#[case] input: &str, #[case] expected: i32) {
372            let result: i32 = expression().then_ignore(end()).parse(input).unwrap();
373            assert_eq!(result, expected);
374        }
375
376        #[rstest]
377        #[case::zero("0", 0.)]
378        #[case::zero_one("01", 1.)]
379        #[case::zero_dot("0.", 0.)]
380        #[case::int("42", 42.)]
381        #[case::with_fract_part("42.42", 42.42)]
382        #[case::thousand("1000", 1_000.)]
383        #[case::thousand_sep("1,000", 1_000.)]
384        fn should_parse_integer(#[case] input: &str, #[case] expected: f64) {
385            let value: f64 = literal().then_ignore(end()).parse(input).unwrap();
386            assert!(
387                (value - expected).abs() <= f64::EPSILON,
388                "{value} should equal {expected}"
389            );
390        }
391
392        #[rstest]
393        #[case::empty("")]
394        #[case::alpha("x")]
395        #[case::start_with_dot(".0")]
396        #[case::start_with_thousand_sep(",1")]
397        #[case::two_dots("1..")]
398        #[case::comma_in_fract_part("1.2,3")]
399        #[case::comma_dot("1,.0")]
400        fn should_not_parse_invalid_value(#[case] input: &str) {
401            let result = literal::<f64>().then_ignore(end()).parse(input);
402            assert!(result.is_err(), "{result:?}");
403        }
404
405        #[rstest]
406        #[case::normal("CHF")]
407        #[case::with_special_chars("USD'42-CHF_EUR.PLN")]
408        #[case::end_with_digit("A2")]
409        fn should_parse_valid_currency(#[case] input: &str) {
410            let currency: Currency = currency().parse(input).unwrap();
411            assert_eq!(currency.as_str(), input);
412        }
413
414        #[rstest]
415        #[case::empty("")]
416        #[case::lowercase("chf")]
417        #[case::start_with_digit("2A")]
418        fn should_not_parse_invalid_currency(#[case] input: &str) {
419            let result: Result<Currency, _> = currency().parse(input);
420            assert!(result.is_err(), "{result:?}");
421        }
422
423        #[rstest]
424        fn should_parse_valid_amount() {
425            let amount: Amount<i32> = amount().parse("1 + 3 CHF").unwrap();
426            assert_eq!(amount.value, 4);
427            assert_eq!(amount.currency.as_str(), "CHF");
428        }
429    }
430}