div_int_procmacro/
lib.rs

1//! Proc macro implementation for the crate [`div-int`](http://docs.rs/div-int).
2extern crate proc_macro;
3
4use std::fmt::{Display, Formatter};
5use proc_macro2::{Span, TokenStream};
6use proc_macro_error::{Diagnostic, Level};
7use crate::ast::{DenominatorKind, Input, Operator};
8use crate::number_literal::OutputType;
9
10/// A compile-time constructor for `DivInt` literals.
11///
12/// There are two ways to invoke this macro:
13/// * `div_int!(N / D)` constructs a `DivInt` with numerator `N` and denominator `D`.
14/// * `div_int!(N * D)` constructs a `DivInt` with denominator `D` and the overall value of `N`.
15///
16/// # Examples
17/// ```
18/// use div_int::{div_int, DivInt};
19///
20/// // Numerator type inferred from context.
21/// assert_eq!(div_int!(15 / 30), DivInt::<u8, 30>::from_numerator(15));
22///
23/// // Denominator inferred from context.
24/// assert_eq!(div_int!(15 / _), DivInt::<u8, 30>::from_numerator(15));
25///
26/// // Explicit numerator type.
27/// assert_eq!(div_int!(15u16 / 30), DivInt::<u16, 30>::from_numerator(15));
28///
29/// // Represent the given fraction.
30/// assert_eq!(div_int!(1.5 * 30), DivInt::<u16, 30>::from_numerator(45));
31/// assert_eq!(f64::from(div_int!(1.5u8 * 30)), 1.5f64);
32///
33/// // Represent the given fraction with a specific numerator type.
34/// assert_eq!(div_int!(1.5u64 * 30), DivInt::<u64, 30>::from_numerator(45));
35/// ```
36
37#[proc_macro_error::proc_macro_error]
38#[proc_macro]
39pub fn div_int(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
40    let tokens: TokenStream = tokens.into();
41    let input = match Input::parse(&mut tokens.into_iter()) {
42        Ok(input) => input,
43        Err(err) => {
44            if let Some(span) = err.span() {
45                proc_macro_error::abort!(span, err);
46            } else {
47                proc_macro_error::abort!(Diagnostic::new(Level::Error, err.to_string()));
48            }
49        }
50    };
51    let code = match emit(&input) {
52        Ok(code) => code,
53        Err(err) => {
54            if let Some(span) = err.span() {
55                proc_macro_error::abort!(span, err);
56            } else {
57                proc_macro_error::abort!(Diagnostic::new(Level::Error, err.to_string()));
58            }
59        },
60    };
61
62    code.parse().expect("failed to produce valid macro output")
63}
64
65mod number_literal;
66mod ast;
67
68#[derive(Debug)]
69enum EmitError {
70    DivFormFloat(Span),
71    NotDivisible(Span),
72    OutsideTypeRange(Span, OutputType),
73    MulFormInferredDenominator(Span),
74}
75
76impl EmitError {
77    fn span(&self) -> Option<&Span> {
78        match self {
79            EmitError::DivFormFloat(span) => Some(span),
80            EmitError::NotDivisible(span) => Some(span),
81            EmitError::OutsideTypeRange(span, _) => Some(span),
82            EmitError::MulFormInferredDenominator(span) => Some(span),
83        }
84    }
85}
86
87impl Display for EmitError {
88    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
89        match self {
90            EmitError::DivFormFloat(_) => f.write_str("Floating point number cannot be used with the \"div\" form of div_int"),
91            EmitError::NotDivisible(_) => f.write_str("Denominator does not divide the provided numerator"),
92            EmitError::OutsideTypeRange(_, OutputType::Unknown) =>
93                f.write_str("Provided value is outside output type range"),
94            EmitError::OutsideTypeRange(_, output_type) => {
95                f.write_str("Provided value is outside output type range (")?;
96                f.write_str(output_type.to_rust_type())?;
97                f.write_str(")")
98            }
99            EmitError::MulFormInferredDenominator(_) =>
100                f.write_str("Denominator must be provided when using the \"mul\" form of div_int"),
101        }
102    }
103}
104
105fn emit(input: &Input) -> Result<String, EmitError> {
106    let numerator = match input.operator {
107        Operator::Div => {
108            if input.numerator.divider != 1 {
109                return Err(EmitError::DivFormFloat(input.numerator.span));
110            }
111            input.numerator.value
112        },
113        Operator::Mul => {
114            let divider = input.numerator.divider;
115            let denominator = match input.denominator.kind {
116                DenominatorKind::Inferred => return Err(EmitError::MulFormInferredDenominator(input.denominator.span)),
117                DenominatorKind::Explicit(d) => d
118            };
119
120            let Some(value) = input.numerator.value.checked_mul(denominator as i128) else {
121                return Err(EmitError::OutsideTypeRange(input.numerator.span, input.numerator.output_type));
122            };
123
124            if value % (divider as i128) != 0 {
125                return Err(EmitError::NotDivisible(input.numerator.span));
126            }
127
128            value / (divider as i128)
129        }
130    };
131    if !input.numerator.output_type.contains(numerator) {
132        return Err(EmitError::OutsideTypeRange(input.numerator.span, input.numerator.output_type));
133    }
134    let output_type = input.numerator.output_type.to_rust_type();
135
136
137    Ok(match input.denominator.kind {
138        DenominatorKind::Explicit(d) => format!("::div_int::DivInt::<{output_type}, {d}>::from_numerator({numerator})"),
139        DenominatorKind::Inferred => format!("::div_int::InferredDenominator::<{output_type}>::div_int({numerator})"),
140    })
141}
142
143
#[cfg(test)]
mod tests {
    use super::*;
    use proc_macro2::TokenStream;
    use assert_matches::assert_matches;

    /// Parses `input` as the body of a `div_int!` invocation and runs [`emit`] on it.
    ///
    /// Panics if the input fails to tokenize or parse, so each test only observes the
    /// outcome of `emit`.
    fn parse_and_emit(input: &str) -> Result<String, EmitError> {
        let stream: TokenStream = input.parse().expect("Failed to parse test input");
        let input = Input::parse(&mut stream.into_iter()).expect("Failed to parse test input");
        emit(&input)
    }

    #[test]
    fn div_form() {
        assert_matches!(parse_and_emit("3 / 5").as_deref(), Ok("::div_int::DivInt::<_, 5>::from_numerator(3)"));
    }

    #[test]
    fn div_form_with_output_type() {
        assert_matches!(parse_and_emit("3u8 / 5").as_deref(), Ok("::div_int::DivInt::<u8, 5>::from_numerator(3)"));
    }

    #[test]
    fn div_form_with_implicit_types() {
        assert_matches!(parse_and_emit("3 / _").as_deref(), Ok("::div_int::InferredDenominator::<_>::div_int(3)"));
    }

    #[test]
    fn mul_form() {
        assert_matches!(parse_and_emit("1.5u16 * 10").as_deref(), Ok("::div_int::DivInt::<u16, 10>::from_numerator(15)"));
    }

    // The mul form without a type suffix leaves the numerator type to inference.
    #[test]
    fn mul_form_without_output_type() {
        assert_matches!(parse_and_emit("1.5 * 30").as_deref(), Ok("::div_int::DivInt::<_, 30>::from_numerator(45)"));
    }

    #[test]
    fn div_form_float() {
        assert_matches!(parse_and_emit("1.5 / 2"), Err(EmitError::DivFormFloat(_)));
    }

    #[test]
    fn outside_type_range() {
        assert_matches!(parse_and_emit("1.5u8 * 200"), Err(EmitError::OutsideTypeRange(_, OutputType::U8)));
    }

    #[test]
    fn mul_inferred_denominator() {
        assert_matches!(parse_and_emit("1 * _"), Err(EmitError::MulFormInferredDenominator(_)));
    }

    #[test]
    fn not_divisible() {
        assert_matches!(parse_and_emit("1.5 * 3"), Err(EmitError::NotDivisible(_)));
    }
}