rhai/packages/
arithmetic.rs

1use crate::plugin::*;
2use crate::{def_package, Position, RhaiError, RhaiResultOf, ERR, INT};
3#[cfg(feature = "no_std")]
4use std::prelude::v1::*;
5
6#[cfg(feature = "no_std")]
7#[cfg(not(feature = "no_float"))]
8use num_traits::Float;
9
10#[cold]
11#[inline(never)]
12pub fn make_err(msg: impl Into<String>) -> RhaiError {
13    ERR::ErrorArithmetic(msg.into(), Position::NONE).into()
14}
15
/// Generate, for each listed integer type, a module `$root::<type>::functions` exporting the
/// binary arithmetic (`+ - * / % **`), shift (`<< >>`) and bitwise (`& | ^`) operators plus
/// the `is_zero` / `is_odd` / `is_even` predicates for that type.
///
/// Unless the `unchecked` feature is enabled, arithmetic overflow and division by zero are
/// reported as arithmetic errors, and shifts by an amount of at least the bit width saturate
/// instead of panicking.
macro_rules! gen_arithmetic_functions {
    ($root:ident => $($arg_type:ident),+) => {
        #[allow(non_snake_case)]
        pub mod $root { $(pub mod $arg_type {
            use super::super::*;

            #[export_module]
            pub mod functions {
                /// Add two numbers, failing on overflow unless `unchecked` is enabled.
                #[rhai_fn(name = "+", return_raw)]
                pub fn add(x: $arg_type, y: $arg_type) -> RhaiResultOf<$arg_type> {
                    if cfg!(not(feature = "unchecked")) {
                        x.checked_add(y).ok_or_else(|| make_err(format!("Addition overflow: {x} + {y}")))
                    } else {
                        Ok(x + y)
                    }
                }
                /// Subtract two numbers, failing on overflow unless `unchecked` is enabled.
                #[rhai_fn(name = "-", return_raw)]
                pub fn subtract(x: $arg_type, y: $arg_type) -> RhaiResultOf<$arg_type> {
                    if cfg!(not(feature = "unchecked")) {
                        x.checked_sub(y).ok_or_else(|| make_err(format!("Subtraction overflow: {x} - {y}")))
                    } else {
                        Ok(x - y)
                    }
                }
                /// Multiply two numbers, failing on overflow unless `unchecked` is enabled.
                #[rhai_fn(name = "*", return_raw)]
                pub fn multiply(x: $arg_type, y: $arg_type) -> RhaiResultOf<$arg_type> {
                    if cfg!(not(feature = "unchecked")) {
                        x.checked_mul(y).ok_or_else(|| make_err(format!("Multiplication overflow: {x} * {y}")))
                    } else {
                        Ok(x * y)
                    }
                }
                /// Divide two numbers, failing on division by zero or overflow unless
                /// `unchecked` is enabled.
                #[rhai_fn(name = "/", return_raw)]
                pub fn divide(x: $arg_type, y: $arg_type) -> RhaiResultOf<$arg_type> {
                    if cfg!(not(feature = "unchecked")) {
                        // Detect division by zero separately to give a more specific message.
                        if y == 0 {
                            Err(make_err(format!("Division by zero: {x} / {y}")))
                        } else {
                            x.checked_div(y).ok_or_else(|| make_err(format!("Division overflow: {x} / {y}")))
                        }
                    } else {
                        Ok(x / y)
                    }
                }
                /// Remainder of dividing two numbers, failing on division by zero or overflow
                /// unless `unchecked` is enabled.
                #[rhai_fn(name = "%", return_raw)]
                pub fn modulo(x: $arg_type, y: $arg_type) -> RhaiResultOf<$arg_type> {
                    if cfg!(not(feature = "unchecked")) {
                        x.checked_rem(y).ok_or_else(|| make_err(format!("Modulo division by zero or overflow: {x} % {y}")))
                    } else {
                        Ok(x % y)
                    }
                }
                /// Raise a number to a non-negative integer power.
                ///
                /// Unless `unchecked` is enabled, a negative exponent, an exponent that does
                /// not fit in a `u32`, or an overflowing result is reported as an error.
                #[rhai_fn(name = "**", return_raw)]
                pub fn power(x: $arg_type, y: INT) -> RhaiResultOf<$arg_type> {
                    if cfg!(not(feature = "unchecked")) {
                        if cfg!(not(feature = "only_i32")) && y > (u32::MAX as INT) {
                            Err(make_err(format!("Exponential overflow: {x} ** {y}")))
                        } else if y < 0 {
                            Err(make_err(format!("Integer raised to a negative power: {x} ** {y}")))
                        } else {
                            x.checked_pow(y as u32).ok_or_else(|| make_err(format!("Exponential overflow: {x} ** {y}")))
                        }
                    } else {
                        Ok(x.pow(y as u32))
                    }
                }

                /// Shift left by `y` bits; a negative `y` shifts right by `|y|` instead.
                ///
                /// Unless `unchecked` is enabled, shifting left by at least the bit width
                /// yields `0`.
                #[rhai_fn(name = "<<")]
                pub fn shift_left(x: $arg_type, y: INT) -> $arg_type {
                    if cfg!(not(feature = "unchecked")) {
                        if cfg!(not(feature = "only_i32")) && y > (u32::MAX as INT) {
                            0
                        } else if y < 0 {
                            shift_right(x, y.checked_abs().unwrap_or(INT::MAX))
                        } else {
                            x.checked_shl(y as u32).unwrap_or(0)
                        }
                    } else if y < 0 {
                        x >> -y
                    } else {
                        x << y
                    }
                }
                /// Shift right by `y` bits; a negative `y` shifts left by `|y|` instead.
                ///
                /// Unless `unchecked` is enabled, shifting right by at least the bit width
                /// yields `0` for non-negative values and `-1` for negative (signed) values.
                #[rhai_fn(name = ">>")]
                pub fn shift_right(x: $arg_type, y: INT) -> $arg_type {
                    if cfg!(not(feature = "unchecked")) {
                        // Value left after every bit has been shifted out.
                        // `wrapping_shr(u32::MAX)` shifts by `BITS - 1` (the amount is masked),
                        // which leaves the sign fill (0 or -1) for signed types but the top
                        // bit (0 or 1) for unsigned types. The extra `>> 1` clears that bit
                        // for unsigned types and leaves the sign fill of signed types intact.
                        let saturated = x.wrapping_shr(u32::MAX) >> 1;

                        if cfg!(not(feature = "only_i32")) && y > (u32::MAX as INT) {
                            saturated
                        } else if y < 0 {
                            shift_left(x, y.checked_abs().unwrap_or(INT::MAX))
                        } else {
                            x.checked_shr(y as u32).unwrap_or(saturated)
                        }
                    } else if y < 0 {
                        x << -y
                    } else {
                        x >> y
                    }
                }
                /// Bitwise AND of two numbers.
                #[rhai_fn(name = "&")]
                pub const fn binary_and(x: $arg_type, y: $arg_type) -> $arg_type {
                    x & y
                }
                /// Bitwise OR of two numbers.
                #[rhai_fn(name = "|")]
                pub const fn binary_or(x: $arg_type, y: $arg_type) -> $arg_type {
                    x | y
                }
                /// Bitwise XOR of two numbers.
                #[rhai_fn(name = "^")]
                pub const fn binary_xor(x: $arg_type, y: $arg_type) -> $arg_type {
                    x ^ y
                }
                /// Return true if the number is zero.
                #[rhai_fn(get = "is_zero", name = "is_zero")]
                pub const fn is_zero(x: $arg_type) -> bool {
                    x == 0
                }
                /// Return true if the number is odd.
                #[rhai_fn(get = "is_odd", name = "is_odd")]
                pub const fn is_odd(x: $arg_type) -> bool {
                    x % 2 != 0
                }
                /// Return true if the number is even.
                #[rhai_fn(get = "is_even", name = "is_even")]
                pub const fn is_even(x: $arg_type) -> bool {
                    x % 2 == 0
                }
            }
        })* }
    }
}
147
/// Generate, for each listed signed integer type, a module `$root::<type>::functions`
/// exporting unary `-` / `+`, `abs` and `sign` for that type.
///
/// Unless the `unchecked` feature is enabled, negating or taking the absolute value of the
/// type's minimum value is reported as an arithmetic error.
macro_rules! gen_signed_functions {
    ($root:ident => $($num:ident),+) => {
        #[allow(non_snake_case)]
        pub mod $root {
            $(
                pub mod $num {
                    use super::super::*;

                    #[export_module]
                    pub mod functions {
                        #[rhai_fn(name = "-", return_raw)]
                        pub fn neg(x: $num) -> RhaiResultOf<$num> {
                            if cfg!(feature = "unchecked") {
                                return Ok(-x);
                            }
                            match x.checked_neg() {
                                Some(negated) => Ok(negated),
                                None => Err(make_err(format!("Negation overflow: -{x}"))),
                            }
                        }
                        #[rhai_fn(name = "+")]
                        pub const fn plus(x: $num) -> $num {
                            x
                        }
                        /// Return the absolute value of the number.
                        #[rhai_fn(return_raw)]
                        pub fn abs(x: $num) -> RhaiResultOf<$num> {
                            if cfg!(feature = "unchecked") {
                                return Ok(x.abs());
                            }
                            match x.checked_abs() {
                                Some(magnitude) => Ok(magnitude),
                                None => Err(make_err(format!("Negation overflow: -{x}"))),
                            }
                        }
                        /// Return the sign (as an integer) of the number according to the following:
                        ///
                        /// * `0` if the number is zero
                        /// * `1` if the number is positive
                        /// * `-1` if the number is negative
                        pub const fn sign(x: $num) -> INT {
                            if x == 0 {
                                0
                            } else if x > 0 {
                                1
                            } else {
                                -1
                            }
                        }
                    }
                }
            )*
        }
    };
}
189
/// Merge `$root::<type>::functions` into the module `$lib` for every listed type,
/// passing `"arithmetic"` as the name argument of `combine_with_exported_module!`.
///
/// Usage: `reg_functions!(lib += some_root; i8, u8, ...)`.
macro_rules! reg_functions {
    ($lib:ident += $root:ident ; $($ty:ident),+ ) => {
        $(
            combine_with_exported_module!($lib, "arithmetic", $root::$ty::functions);
        )*
    };
}
195
def_package! {
    /// Basic arithmetic package.
    pub ArithmeticPackage(lib) {
        lib.set_standard_lib(true);

        // Predicates for `INT` (`is_zero`, `is_odd`, `is_even`) plus its unary/`abs`/`sign`
        // functions.
        // NOTE(review): the binary `INT` operators generated in `arith_basic` are not
        // registered here — presumably they are provided by built-in operator handling
        // elsewhere in the crate; confirm.
        combine_with_exported_module!(lib, "int", int_functions);
        reg_functions!(lib += signed_basic; INT);

        // Operators for the additional integer types. Skipped when either the `only_i32`
        // or the `only_i64` feature is enabled.
        #[cfg(not(feature = "only_i32"))]
        #[cfg(not(feature = "only_i64"))]
        {
            gen_arithmetic_functions!(arith_numbers => i8, u8, i16, u16, i32, u32, u64);
            reg_functions!(lib += arith_numbers; i8, u8, i16, u16, i32, u32, u64);
            gen_signed_functions!(signed_numbers => i8, i16, i32);
            reg_functions!(lib += signed_numbers; i8, i16, i32);

            // 128-bit integer types are not registered on `wasm` targets. The modules
            // generated in this inner block shadow the `arith_numbers` / `signed_numbers`
            // modules of the enclosing block.
            #[cfg(not(target_family = "wasm"))]
            {
                gen_arithmetic_functions!(arith_numbers => i128, u128);
                reg_functions!(lib += arith_numbers; i128, u128);
                gen_signed_functions!(signed_numbers => i128);
                reg_functions!(lib += signed_numbers; i128);
            }
        }

        // Basic arithmetic for floating-point
        #[cfg(not(feature = "no_float"))]
        {
            combine_with_exported_module!(lib, "f32", f32_functions);
            combine_with_exported_module!(lib, "f64", f64_functions);
        }

        // Decimal functions
        #[cfg(feature = "decimal")]
        combine_with_exported_module!(lib, "decimal", decimal_functions);
    }
}
233
234#[export_module]
235mod int_functions {
236    /// Return true if the number is zero.
237    #[rhai_fn(get = "is_zero", name = "is_zero")]
238    pub const fn is_zero(x: INT) -> bool {
239        x == 0
240    }
241    /// Return true if the number is odd.
242    #[rhai_fn(get = "is_odd", name = "is_odd")]
243    pub const fn is_odd(x: INT) -> bool {
244        x % 2 != 0
245    }
246    /// Return true if the number is even.
247    #[rhai_fn(get = "is_even", name = "is_even")]
248    pub const fn is_even(x: INT) -> bool {
249        x % 2 == 0
250    }
251}
252
// Operator modules for the default integer type `INT`.
// `signed_basic` is registered by `ArithmeticPackage`; `arith_basic` is not registered there.
// NOTE(review): presumably `arith_basic` is referenced elsewhere in the crate (e.g. by
// built-in operator handling) — confirm before removing it.
gen_arithmetic_functions!(arith_basic => INT);
gen_signed_functions!(signed_basic => INT);
255
256#[cfg(not(feature = "no_float"))]
257#[export_module]
258mod f32_functions {
259    #[cfg(not(feature = "f32_float"))]
260    #[allow(clippy::cast_precision_loss)]
261    pub mod basic_arithmetic {
262        #[rhai_fn(name = "+")]
263        pub fn add(x: f32, y: f32) -> f32 {
264            x + y
265        }
266        #[rhai_fn(name = "-")]
267        pub fn subtract(x: f32, y: f32) -> f32 {
268            x - y
269        }
270        #[rhai_fn(name = "*")]
271        pub fn multiply(x: f32, y: f32) -> f32 {
272            x * y
273        }
274        #[rhai_fn(name = "/")]
275        pub fn divide(x: f32, y: f32) -> f32 {
276            x / y
277        }
278        #[rhai_fn(name = "%")]
279        pub fn modulo(x: f32, y: f32) -> f32 {
280            x % y
281        }
282        #[rhai_fn(name = "**")]
283        pub fn pow_f_f(x: f32, y: f32) -> f32 {
284            x.powf(y)
285        }
286
287        #[rhai_fn(name = "+")]
288        pub fn add_if(x: INT, y: f32) -> f32 {
289            (x as f32) + y
290        }
291        #[rhai_fn(name = "+")]
292        pub fn add_fi(x: f32, y: INT) -> f32 {
293            x + (y as f32)
294        }
295        #[rhai_fn(name = "-")]
296        pub fn subtract_if(x: INT, y: f32) -> f32 {
297            (x as f32) - y
298        }
299        #[rhai_fn(name = "-")]
300        pub fn subtract_fi(x: f32, y: INT) -> f32 {
301            x - (y as f32)
302        }
303        #[rhai_fn(name = "*")]
304        pub fn multiply_if(x: INT, y: f32) -> f32 {
305            (x as f32) * y
306        }
307        #[rhai_fn(name = "*")]
308        pub fn multiply_fi(x: f32, y: INT) -> f32 {
309            x * (y as f32)
310        }
311        #[rhai_fn(name = "/")]
312        pub fn divide_if(x: INT, y: f32) -> f32 {
313            (x as f32) / y
314        }
315        #[rhai_fn(name = "/")]
316        pub fn divide_fi(x: f32, y: INT) -> f32 {
317            x / (y as f32)
318        }
319        #[rhai_fn(name = "%")]
320        pub fn modulo_if(x: INT, y: f32) -> f32 {
321            (x as f32) % y
322        }
323        #[rhai_fn(name = "%")]
324        pub fn modulo_fi(x: f32, y: INT) -> f32 {
325            x % (y as f32)
326        }
327    }
328
329    #[rhai_fn(name = "-")]
330    pub fn neg(x: f32) -> f32 {
331        -x
332    }
333    #[rhai_fn(name = "+")]
334    pub const fn plus(x: f32) -> f32 {
335        x
336    }
337    /// Return the absolute value of the floating-point number.
338    pub fn abs(x: f32) -> f32 {
339        x.abs()
340    }
341    /// Return the sign (as an integer) of the floating-point number according to the following:
342    ///
343    /// * `0` if the number is zero
344    /// * `1` if the number is positive
345    /// * `-1` if the number is negative
346    #[rhai_fn(return_raw)]
347    pub fn sign(x: f32) -> RhaiResultOf<INT> {
348        match x.signum() {
349            _ if x == 0.0 => Ok(0),
350            x if x.is_nan() => Err(make_err("Sign of NaN is undefined")),
351            x => Ok(x as INT),
352        }
353    }
354    /// Return true if the floating-point number is zero.
355    #[rhai_fn(get = "is_zero", name = "is_zero")]
356    pub fn is_zero(x: f32) -> bool {
357        x == 0.0
358    }
359    #[rhai_fn(name = "**", return_raw)]
360    pub fn pow_f_i(x: f32, y: INT) -> RhaiResultOf<f32> {
361        if cfg!(not(feature = "unchecked")) && y > (i32::MAX as INT) {
362            Err(make_err(format!(
363                "Number raised to too large an index: {x} ** {y}"
364            )))
365        } else {
366            #[allow(clippy::cast_possible_truncation, clippy::unnecessary_cast)]
367            Ok(x.powi(y as i32))
368        }
369    }
370}
371
372#[cfg(not(feature = "no_float"))]
373#[export_module]
374mod f64_functions {
375    #[cfg(feature = "f32_float")]
376    pub mod basic_arithmetic {
377        #[rhai_fn(name = "+")]
378        pub fn add(x: f64, y: f64) -> f64 {
379            x + y
380        }
381        #[rhai_fn(name = "-")]
382        pub fn subtract(x: f64, y: f64) -> f64 {
383            x - y
384        }
385        #[rhai_fn(name = "*")]
386        pub fn multiply(x: f64, y: f64) -> f64 {
387            x * y
388        }
389        #[rhai_fn(name = "/")]
390        pub fn divide(x: f64, y: f64) -> f64 {
391            x / y
392        }
393        #[rhai_fn(name = "%")]
394        pub fn modulo(x: f64, y: f64) -> f64 {
395            x % y
396        }
397        #[rhai_fn(name = "**")]
398        pub fn pow_f_f(x: f64, y: f64) -> f64 {
399            x.powf(y)
400        }
401
402        #[rhai_fn(name = "+")]
403        pub fn add_if(x: INT, y: f64) -> f64 {
404            (x as f64) + y
405        }
406        #[rhai_fn(name = "+")]
407        pub fn add_fi(x: f64, y: INT) -> f64 {
408            x + (y as f64)
409        }
410        #[rhai_fn(name = "-")]
411        pub fn subtract_if(x: INT, y: f64) -> f64 {
412            (x as f64) - y
413        }
414        #[rhai_fn(name = "-")]
415        pub fn subtract_fi(x: f64, y: INT) -> f64 {
416            x - (y as f64)
417        }
418        #[rhai_fn(name = "*")]
419        pub fn multiply_if(x: INT, y: f64) -> f64 {
420            (x as f64) * y
421        }
422        #[rhai_fn(name = "*")]
423        pub fn multiply_fi(x: f64, y: INT) -> f64 {
424            x * (y as f64)
425        }
426        #[rhai_fn(name = "/")]
427        pub fn divide_if(x: INT, y: f64) -> f64 {
428            (x as f64) / y
429        }
430        #[rhai_fn(name = "/")]
431        pub fn divide_fi(x: f64, y: INT) -> f64 {
432            x / (y as f64)
433        }
434        #[rhai_fn(name = "%")]
435        pub fn modulo_if(x: INT, y: f64) -> f64 {
436            (x as f64) % y
437        }
438        #[rhai_fn(name = "%")]
439        pub fn modulo_fi(x: f64, y: INT) -> f64 {
440            x % (y as f64)
441        }
442    }
443
444    #[rhai_fn(name = "-")]
445    pub fn neg(x: f64) -> f64 {
446        -x
447    }
448    #[rhai_fn(name = "+")]
449    pub const fn plus(x: f64) -> f64 {
450        x
451    }
452    /// Return the absolute value of the floating-point number.
453    pub fn abs(x: f64) -> f64 {
454        x.abs()
455    }
456    /// Return the sign (as an integer) of the floating-point number according to the following:
457    ///
458    /// * `0` if the number is zero
459    /// * `1` if the number is positive
460    /// * `-1` if the number is negative
461    #[rhai_fn(return_raw)]
462    pub fn sign(x: f64) -> RhaiResultOf<INT> {
463        match x.signum() {
464            _ if x == 0.0 => Ok(0),
465            x if x.is_nan() => Err(make_err("Sign of NaN is undefined")),
466            x => Ok(x as INT),
467        }
468    }
469    /// Return true if the floating-point number is zero.
470    #[rhai_fn(get = "is_zero", name = "is_zero")]
471    pub fn is_zero(x: f64) -> bool {
472        x == 0.0
473    }
474}
475
/// Functions on `Decimal` values. Only compiled when the `decimal` feature is enabled.
#[cfg(feature = "decimal")]
#[export_module]
pub mod decimal_functions {
    use rust_decimal::{prelude::Zero, Decimal};

    /// Overflow-checked operations on `Decimal`. Only compiled when `unchecked` is disabled.
    #[cfg(not(feature = "unchecked"))]
    pub mod builtin {
        use rust_decimal::MathematicalOps;

        /// `x + y`, failing with an arithmetic error on overflow.
        #[rhai_fn(return_raw)]
        pub fn add(x: Decimal, y: Decimal) -> RhaiResultOf<Decimal> {
            x.checked_add(y)
                .ok_or_else(|| make_err(format!("Addition overflow: {x} + {y}")))
        }
        /// `x - y`, failing with an arithmetic error on overflow.
        #[rhai_fn(return_raw)]
        pub fn subtract(x: Decimal, y: Decimal) -> RhaiResultOf<Decimal> {
            x.checked_sub(y)
                .ok_or_else(|| make_err(format!("Subtraction overflow: {x} - {y}")))
        }
        /// `x * y`, failing with an arithmetic error on overflow.
        #[rhai_fn(return_raw)]
        pub fn multiply(x: Decimal, y: Decimal) -> RhaiResultOf<Decimal> {
            x.checked_mul(y)
                .ok_or_else(|| make_err(format!("Multiplication overflow: {x} * {y}")))
        }
        /// `x / y`, failing with an arithmetic error on division by zero or overflow.
        #[rhai_fn(return_raw)]
        pub fn divide(x: Decimal, y: Decimal) -> RhaiResultOf<Decimal> {
            // Detect division by zero separately to give a more specific message.
            if y == Decimal::zero() {
                Err(make_err(format!("Division by zero: {x} / {y}")))
            } else {
                x.checked_div(y)
                    .ok_or_else(|| make_err(format!("Division overflow: {x} / {y}")))
            }
        }
        /// `x % y`, failing with an arithmetic error on division by zero or overflow.
        #[rhai_fn(return_raw)]
        pub fn modulo(x: Decimal, y: Decimal) -> RhaiResultOf<Decimal> {
            x.checked_rem(y)
                .ok_or_else(|| make_err(format!("Modulo division by zero or overflow: {x} % {y}")))
        }
        /// `x ** y`, failing with an arithmetic error on overflow or when the magnitude of
        /// `y` rounds to more than 1,000,000.
        #[rhai_fn(return_raw)]
        pub fn power(x: Decimal, y: Decimal) -> RhaiResultOf<Decimal> {
            // Raising to a very large power can take exponential time, so limit the
            // exponent's magnitude to 1 million. The magnitude `|y|` is checked rather than
            // `y` itself: a negative value never converts to `u32`, so checking `y` rejected
            // every negative exponent (e.g. `2 ** -1`) as "overflow".
            // TODO: Remove this limit when `rust-decimal` is updated with the fix.
            let exponent_too_large = std::convert::TryInto::<u32>::try_into(y.abs().round())
                .map_or(true, |v| v > 1_000_000);

            if exponent_too_large {
                return Err(make_err(format!("Exponential overflow: {x} ** {y}")));
            }
            x.checked_powd(y)
                .ok_or_else(|| make_err(format!("Exponential overflow: {x} ** {y}")))
        }
    }
    #[rhai_fn(name = "-")]
    pub fn neg(x: Decimal) -> Decimal {
        -x
    }
    #[rhai_fn(name = "+")]
    pub const fn plus(x: Decimal) -> Decimal {
        x
    }
    /// Return the absolute value of the decimal number.
    pub fn abs(x: Decimal) -> Decimal {
        x.abs()
    }
    /// Return the sign (as an integer) of the decimal number according to the following:
    ///
    /// * `0` if the number is zero
    /// * `1` if the number is positive
    /// * `-1` if the number is negative
    pub fn sign(x: Decimal) -> INT {
        if x == Decimal::zero() {
            0
        } else if x.is_sign_negative() {
            -1
        } else {
            1
        }
    }
    /// Return true if the decimal number is zero.
    #[rhai_fn(get = "is_zero", name = "is_zero")]
    pub const fn is_zero(x: Decimal) -> bool {
        x.is_zero()
    }
}