1use crate::error::{ArithmeticError, ParseError};
4use crate::rounding::RoundingMode;
5use core::cmp::Ordering;
6use core::fmt;
7use core::ops::{Add, Div, Mul, Neg, Sub};
8use core::str::FromStr;
9use num_traits::Signed;
10use rust_decimal::Decimal as RustDecimal;
11use serde::{Deserialize, Serialize};
12
/// Largest scale (number of fractional digits) accepted by [`Decimal::rescale`];
/// it matches the 28-digit precision limit of `rust_decimal`.
pub const MAX_SCALE: u32 = 28;
15
16#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
22#[serde(transparent)]
23pub struct Decimal(RustDecimal);
24
25impl Decimal {
26 pub const ZERO: Self = Self(RustDecimal::ZERO);
28
29 pub const ONE: Self = Self(RustDecimal::ONE);
31
32 pub const NEGATIVE_ONE: Self = Self(RustDecimal::NEGATIVE_ONE);
34
35 pub const TEN: Self = Self(RustDecimal::TEN);
37
38 pub const ONE_HUNDRED: Self = Self(RustDecimal::ONE_HUNDRED);
40
41 pub const ONE_THOUSAND: Self = Self(RustDecimal::ONE_THOUSAND);
43
44 pub const MAX: Self = Self(RustDecimal::MAX);
46
47 pub const MIN: Self = Self(RustDecimal::MIN);
49
50 #[must_use]
58 pub fn new(mantissa: i64, scale: u32) -> Self {
59 Self(RustDecimal::new(mantissa, scale))
60 }
61
62 #[must_use]
67 pub const fn from_parts(lo: u32, mid: u32, hi: u32, negative: bool, scale: u32) -> Self {
68 Self(RustDecimal::from_parts(lo, mid, hi, negative, scale))
69 }
70
71 #[must_use]
73 pub fn to_parts(self) -> (i128, u32) {
74 let unpacked = self.0.unpack();
75 let mantissa = i128::from(unpacked.lo)
76 | (i128::from(unpacked.mid) << 32)
77 | (i128::from(unpacked.hi) << 64);
78 let signed = if unpacked.negative {
79 -mantissa
80 } else {
81 mantissa
82 };
83 (signed, unpacked.scale)
84 }
85
86 #[must_use]
88 pub fn scale(self) -> u32 {
89 self.0.scale()
90 }
91
92 #[must_use]
94 pub fn is_zero(self) -> bool {
95 self.0.is_zero()
96 }
97
98 #[must_use]
100 pub fn is_negative(self) -> bool {
101 self.0.is_sign_negative()
102 }
103
104 #[must_use]
106 pub fn is_positive(self) -> bool {
107 self.0.is_sign_positive() && !self.0.is_zero()
108 }
109
110 #[must_use]
112 pub fn abs(self) -> Self {
113 Self(self.0.abs())
114 }
115
116 #[must_use]
118 pub fn signum(self) -> Self {
119 Self(self.0.signum())
120 }
121
122 #[must_use]
124 pub fn checked_add(self, other: Self) -> Option<Self> {
125 self.0.checked_add(other.0).map(Self)
126 }
127
128 #[must_use]
130 pub fn checked_sub(self, other: Self) -> Option<Self> {
131 self.0.checked_sub(other.0).map(Self)
132 }
133
134 #[must_use]
136 pub fn checked_mul(self, other: Self) -> Option<Self> {
137 self.0.checked_mul(other.0).map(Self)
138 }
139
140 #[must_use]
142 pub fn checked_div(self, other: Self) -> Option<Self> {
143 self.0.checked_div(other.0).map(Self)
144 }
145
146 #[must_use]
148 pub fn checked_rem(self, other: Self) -> Option<Self> {
149 self.0.checked_rem(other.0).map(Self)
150 }
151
152 #[must_use]
154 pub fn saturating_add(self, other: Self) -> Self {
155 Self(self.0.saturating_add(other.0))
156 }
157
158 #[must_use]
160 pub fn saturating_sub(self, other: Self) -> Self {
161 Self(self.0.saturating_sub(other.0))
162 }
163
164 #[must_use]
166 pub fn saturating_mul(self, other: Self) -> Self {
167 Self(self.0.saturating_mul(other.0))
168 }
169
170 pub fn try_add(self, other: Self) -> Result<Self, ArithmeticError> {
172 self.checked_add(other).ok_or(ArithmeticError::Overflow)
173 }
174
175 pub fn try_sub(self, other: Self) -> Result<Self, ArithmeticError> {
177 self.checked_sub(other).ok_or(ArithmeticError::Overflow)
178 }
179
180 pub fn try_mul(self, other: Self) -> Result<Self, ArithmeticError> {
182 self.checked_mul(other).ok_or(ArithmeticError::Overflow)
183 }
184
185 pub fn try_div(self, other: Self) -> Result<Self, ArithmeticError> {
187 if other.is_zero() {
188 return Err(ArithmeticError::DivisionByZero);
189 }
190 self.checked_div(other).ok_or(ArithmeticError::Overflow)
191 }
192
193 #[must_use]
195 pub fn round(self, dp: u32, mode: RoundingMode) -> Self {
196 Self(self.0.round_dp_with_strategy(dp, mode.to_rust_decimal()))
197 }
198
199 #[must_use]
201 pub fn round_dp(self, dp: u32) -> Self {
202 self.round(dp, RoundingMode::HalfEven)
203 }
204
205 #[must_use]
207 pub fn trunc(self, dp: u32) -> Self {
208 self.round(dp, RoundingMode::TowardZero)
209 }
210
211 #[must_use]
213 pub fn floor(self) -> Self {
214 Self(self.0.floor())
215 }
216
217 #[must_use]
219 pub fn ceil(self) -> Self {
220 Self(self.0.ceil())
221 }
222
223 #[must_use]
225 pub fn normalize(self) -> Self {
226 Self(self.0.normalize())
227 }
228
229 pub fn rescale(&mut self, scale: u32) -> Result<(), ArithmeticError> {
233 if scale > MAX_SCALE {
234 return Err(ArithmeticError::ScaleExceeded);
235 }
236 self.0.rescale(scale);
237 Ok(())
238 }
239
240 #[must_use]
242 pub fn min(self, other: Self) -> Self {
243 if self <= other {
244 self
245 } else {
246 other
247 }
248 }
249
250 #[must_use]
252 pub fn max(self, other: Self) -> Self {
253 if self >= other {
254 self
255 } else {
256 other
257 }
258 }
259
260 #[must_use]
262 pub fn clamp(self, min: Self, max: Self) -> Self {
263 self.max(min).min(max)
264 }
265
266 #[must_use]
268 pub fn into_inner(self) -> RustDecimal {
269 self.0
270 }
271
272 #[must_use]
274 pub fn from_inner(inner: RustDecimal) -> Self {
275 Self(inner)
276 }
277}
278
279impl Default for Decimal {
280 fn default() -> Self {
281 Self::ZERO
282 }
283}
284
285impl fmt::Debug for Decimal {
286 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287 write!(f, "Decimal({})", self.0)
288 }
289}
290
291impl fmt::Display for Decimal {
292 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293 fmt::Display::fmt(&self.0, f)
294 }
295}
296
297impl FromStr for Decimal {
298 type Err = ParseError;
299
300 fn from_str(s: &str) -> Result<Self, Self::Err> {
301 if s.is_empty() {
302 return Err(ParseError::Empty);
303 }
304 RustDecimal::from_str(s)
305 .map(Self)
306 .map_err(|_| ParseError::InvalidCharacter)
307 }
308}
309
310impl PartialOrd for Decimal {
311 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
312 Some(self.cmp(other))
313 }
314}
315
316impl Ord for Decimal {
317 fn cmp(&self, other: &Self) -> Ordering {
318 self.0.cmp(&other.0)
319 }
320}
321
322impl Neg for Decimal {
323 type Output = Self;
324
325 fn neg(self) -> Self::Output {
326 Self(-self.0)
327 }
328}
329
330impl Add for Decimal {
331 type Output = Self;
332
333 fn add(self, other: Self) -> Self::Output {
334 self.checked_add(other).expect("decimal overflow")
335 }
336}
337
338impl Sub for Decimal {
339 type Output = Self;
340
341 fn sub(self, other: Self) -> Self::Output {
342 self.checked_sub(other).expect("decimal overflow")
343 }
344}
345
346impl Mul for Decimal {
347 type Output = Self;
348
349 fn mul(self, other: Self) -> Self::Output {
350 self.checked_mul(other).expect("decimal overflow")
351 }
352}
353
354impl Div for Decimal {
355 type Output = Self;
356
357 fn div(self, other: Self) -> Self::Output {
358 self.checked_div(other).expect("decimal division error")
359 }
360}
361
362macro_rules! impl_from_int {
363 ($($t:ty),*) => {
364 $(
365 impl From<$t> for Decimal {
366 fn from(n: $t) -> Self {
367 Self(RustDecimal::from(n))
368 }
369 }
370 )*
371 };
372}
373
374impl_from_int!(i8, i16, i32, i64, u8, u16, u32, u64);
375
376impl From<i128> for Decimal {
377 fn from(n: i128) -> Self {
378 Self(RustDecimal::from(n))
379 }
380}
381
382impl From<u128> for Decimal {
383 fn from(n: u128) -> Self {
384 Self(RustDecimal::from(n))
385 }
386}
387
#[cfg(test)]
mod tests {
    extern crate alloc;

    use super::*;
    use alloc::string::ToString;

    /// Parses a decimal literal; test inputs are always well-formed.
    fn dec(text: &str) -> Decimal {
        Decimal::from_str(text).unwrap()
    }

    /// Builds an integral decimal from an `i64`.
    fn int(value: i64) -> Decimal {
        Decimal::from(value)
    }

    #[test]
    fn zero_identity() {
        let x = int(42);
        let zero = Decimal::ZERO;
        assert_eq!(x + zero, x);
        assert_eq!(x - zero, x);
        assert_eq!(x * zero, zero);
    }

    #[test]
    fn one_identity() {
        let x = int(42);
        let one = Decimal::ONE;
        assert_eq!(x * one, x);
        assert_eq!(x / one, x);
    }

    #[test]
    fn negation() {
        let x = int(42);
        let negated = -x;
        assert_eq!(-negated, x);
        assert_eq!(x + negated, Decimal::ZERO);
    }

    #[test]
    fn basic_arithmetic() {
        let one = Decimal::new(100, 2);
        let two = Decimal::new(200, 2);
        let divisor = int(2);
        assert_eq!(one + two, Decimal::new(300, 2));
        assert_eq!(two - one, Decimal::new(100, 2));
        assert_eq!(one * divisor, two);
        assert_eq!(two / divisor, one);
    }

    #[test]
    fn division_precision() {
        let third = int(1) / int(3);
        assert_eq!(third.round_dp(6), dec("0.333333"));
    }

    #[test]
    fn rounding_modes() {
        let two_and_a_half = dec("2.5");
        let cases = [
            (RoundingMode::HalfEven, 2),
            (RoundingMode::HalfUp, 3),
            (RoundingMode::Down, 2),
            (RoundingMode::Up, 3),
        ];
        for (mode, expected) in cases {
            assert_eq!(two_and_a_half.round(0, mode), int(expected));
        }

        // Banker's rounding goes to the even neighbour: 3.5 -> 4.
        assert_eq!(dec("3.5").round(0, RoundingMode::HalfEven), int(4));
    }

    #[test]
    fn checked_operations() {
        assert_eq!(Decimal::MAX.checked_add(Decimal::ONE), None);
        assert_eq!(Decimal::MIN.checked_sub(Decimal::ONE), None);
        assert_eq!(Decimal::ZERO.checked_div(Decimal::ZERO), None);
    }

    #[test]
    fn try_operations() {
        let overflow = Decimal::MAX.try_add(Decimal::ONE);
        assert!(matches!(overflow, Err(ArithmeticError::Overflow)));

        let by_zero = Decimal::ONE.try_div(Decimal::ZERO);
        assert!(matches!(by_zero, Err(ArithmeticError::DivisionByZero)));
    }

    #[test]
    fn parse_and_display() {
        for text in ["123.456", "-0.001"] {
            let parsed: Decimal = text.parse().unwrap();
            assert_eq!(parsed.to_string(), text);
        }
    }

    #[test]
    fn ordering() {
        let (one, two) = (int(1), int(2));
        assert!(one < two);
        assert!(two > one);
        assert_eq!(one.min(two), one);
        assert_eq!(one.max(two), two);
    }

    #[test]
    fn abs_and_signum() {
        let five = int(5);
        let minus_five = int(-5);

        assert_eq!(five.abs(), five);
        assert_eq!(minus_five.abs(), five);
        assert_eq!(five.signum(), Decimal::ONE);
        assert_eq!(minus_five.signum(), Decimal::NEGATIVE_ONE);
        assert_eq!(Decimal::ZERO.signum(), Decimal::ZERO);
    }

    #[test]
    fn clamp() {
        let (lo, hi) = (int(0), int(100));

        assert_eq!(int(50).clamp(lo, hi), int(50));
        assert_eq!(int(-10).clamp(lo, hi), lo);
        assert_eq!(int(150).clamp(lo, hi), hi);
    }
}