1use crate::error::{Error, Result};
19use bigdecimal::num_bigint::BigInt;
20use bigdecimal::num_traits::Zero;
21use bigdecimal::{BigDecimal, RoundingMode};
22use std::fmt;
23
24#[cfg(test)]
25use std::str::FromStr;
26
/// Largest declared precision that is stored compactly, with the unscaled value held
/// in an `i64`.
///
/// Any integer with at most 18 decimal digits fits in an `i64`
/// (`i64::MAX` is about 9.22e18), so values at or below this precision never need
/// arbitrary-precision storage for their unscaled form.
pub const MAX_COMPACT_PRECISION: u32 = 18;
30
/// A fixed-precision decimal number with a declared precision and scale.
///
/// A value is "compact" when its declared precision is at most
/// [`MAX_COMPACT_PRECISION`]; compact values carry their unscaled integer in
/// `long_val`. Values built from a `BigDecimal` additionally keep that value,
/// already rescaled to `scale`, in `decimal_val`.
///
/// Equality, ordering and hashing are numeric: `1.0` and `1.00` compare equal even
/// though their scales differ.
#[derive(Debug, Clone, serde::Serialize)]
pub struct Decimal {
    /// Declared maximum number of decimal digits of the unscaled value.
    precision: u32,
    /// Declared number of digits after the decimal point.
    scale: u32,
    /// Unscaled value as an `i64`; `Some` for every compact value produced by this
    /// module's constructors, `None` otherwise.
    long_val: Option<i64>,
    /// The value rescaled to `scale`; set by `from_big_decimal` (and the constructors
    /// that delegate to it), `None` for values built by `from_unscaled_long`.
    decimal_val: Option<BigDecimal>,
}
46
47impl Decimal {
48 pub fn precision(&self) -> u32 {
52 self.precision
53 }
54
55 pub fn scale(&self) -> u32 {
57 self.scale
58 }
59
60 pub fn is_compact(&self) -> bool {
62 self.precision <= MAX_COMPACT_PRECISION
63 }
64
65 pub fn is_compact_precision(precision: u32) -> bool {
67 precision <= MAX_COMPACT_PRECISION
68 }
69
70 pub fn to_big_decimal(&self) -> BigDecimal {
72 if let Some(bd) = &self.decimal_val {
73 bd.clone()
74 } else if let Some(long_val) = self.long_val {
75 BigDecimal::new(BigInt::from(long_val), self.scale as i64)
76 } else {
77 BigDecimal::new(BigInt::from(0), self.scale as i64)
79 }
80 }
81
82 pub fn to_unscaled_long(&self) -> Result<i64> {
84 if let Some(long_val) = self.long_val {
85 Ok(long_val)
86 } else {
87 let bd = self.to_big_decimal();
89 let (unscaled, _) = bd.as_bigint_and_exponent();
90 unscaled.try_into().map_err(|_| Error::IllegalArgument {
91 message: format!(
92 "Decimal unscaled value does not fit in i64: precision={}",
93 self.precision
94 ),
95 })
96 }
97 }
98
99 pub fn to_unscaled_bytes(&self) -> Vec<u8> {
101 let bd = self.to_big_decimal();
102 let (unscaled, _) = bd.as_bigint_and_exponent();
103 unscaled.to_signed_bytes_be()
104 }
105
106 pub fn from_arrow_decimal128(
109 i128_val: i128,
110 arrow_scale: i64,
111 precision: u32,
112 scale: u32,
113 ) -> Result<Self> {
114 let bd = BigDecimal::new(BigInt::from(i128_val), arrow_scale);
115 Self::from_big_decimal(bd, precision, scale)
116 }
117
118 pub fn from_big_decimal(bd: BigDecimal, precision: u32, scale: u32) -> Result<Self> {
123 let scaled = bd.with_scale_round(scale as i64, RoundingMode::HalfUp);
125
126 let (unscaled, exp) = scaled.as_bigint_and_exponent();
128
129 debug_assert_eq!(
131 exp, scale as i64,
132 "Scaled decimal exponent ({exp}) != expected scale ({scale})"
133 );
134
135 let actual_precision = Self::compute_precision(&unscaled);
136 if actual_precision > precision as usize {
137 return Err(Error::IllegalArgument {
138 message: format!(
139 "Decimal precision overflow: value has {actual_precision} digits but precision is {precision} (value: {scaled})"
140 ),
141 });
142 }
143
144 let long_val = if precision <= MAX_COMPACT_PRECISION {
146 Some(i64::try_from(&unscaled).map_err(|_| Error::IllegalArgument {
147 message: format!(
148 "Decimal mantissa exceeds i64 range for compact precision {precision}: unscaled={unscaled} (value={scaled})"
149 ),
150 })?)
151 } else {
152 None
153 };
154
155 Ok(Decimal {
156 precision,
157 scale,
158 long_val,
159 decimal_val: Some(scaled),
160 })
161 }
162
163 pub fn from_unscaled_long(unscaled_long: i64, precision: u32, scale: u32) -> Result<Self> {
165 if precision > MAX_COMPACT_PRECISION {
166 return Err(Error::IllegalArgument {
167 message: format!(
168 "Precision {precision} exceeds MAX_COMPACT_PRECISION ({MAX_COMPACT_PRECISION})"
169 ),
170 });
171 }
172
173 let actual_precision = Self::compute_precision(&BigInt::from(unscaled_long));
174 if actual_precision > precision as usize {
175 return Err(Error::IllegalArgument {
176 message: format!(
177 "Decimal precision overflow: unscaled value has {actual_precision} digits but precision is {precision}"
178 ),
179 });
180 }
181
182 Ok(Decimal {
183 precision,
184 scale,
185 long_val: Some(unscaled_long),
186 decimal_val: None,
187 })
188 }
189
190 pub fn from_unscaled_bytes(unscaled_bytes: &[u8], precision: u32, scale: u32) -> Result<Self> {
192 let unscaled = BigInt::from_signed_bytes_be(unscaled_bytes);
193 let bd = BigDecimal::new(unscaled, scale as i64);
194 Self::from_big_decimal(bd, precision, scale)
195 }
196
197 pub fn compute_precision(unscaled: &BigInt) -> usize {
199 if unscaled.is_zero() {
200 return 1;
201 }
202
203 unscaled.magnitude().to_str_radix(10).len()
206 }
207}
208
209impl fmt::Display for Decimal {
210 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211 write!(f, "{}", self.to_big_decimal())
212 }
213}
214
215impl PartialEq for Decimal {
217 fn eq(&self, other: &Self) -> bool {
218 self.cmp(other) == std::cmp::Ordering::Equal
221 }
222}
223
// `PartialEq` is defined as `cmp(..) == Equal` over the total order in `Ord`, so
// equality is reflexive and `Eq` may be implemented.
impl Eq for Decimal {}
225
226impl PartialOrd for Decimal {
227 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
228 Some(self.cmp(other))
229 }
230}
231
232impl Ord for Decimal {
233 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
234 if self.is_compact() && other.is_compact() && self.scale == other.scale {
236 self.long_val.cmp(&other.long_val)
237 } else {
238 self.to_big_decimal().cmp(&other.to_big_decimal())
240 }
241 }
242}
243
244impl std::hash::Hash for Decimal {
245 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
246 self.to_big_decimal().hash(state);
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_precision_calculation() {
        // Zero is defined to occupy a single digit.
        assert_eq!(Decimal::compute_precision(&BigInt::from(0)), 1);

        assert_eq!(Decimal::compute_precision(&BigInt::from(10)), 2);
        assert_eq!(Decimal::compute_precision(&BigInt::from(100)), 3);
        assert_eq!(Decimal::compute_precision(&BigInt::from(12300)), 5);
        assert_eq!(
            Decimal::compute_precision(&BigInt::from(10000000000i64)),
            11
        );
        // The sign is not counted as a digit.
        assert_eq!(Decimal::compute_precision(&BigInt::from(-12300)), 5);

        // Rescaling 1 to scale 10 yields the unscaled value 10000000000 (11 digits).
        let bd = BigDecimal::new(BigInt::from(1), 0);
        assert!(
            Decimal::from_big_decimal(bd.clone(), 1, 10).is_err(),
            "Should reject: unscaled 10000000000 has 11 digits, precision=1 is too small"
        );
        assert!(
            Decimal::from_big_decimal(bd, 11, 10).is_ok(),
            "Should accept with correct precision=11"
        );
    }

    #[test]
    fn test_precision_validation() {
        // (unscaled, scale, smallest precision that can hold the value)
        let test_cases = vec![
            (10i64, 1, 2),            // 1.0
            (100i64, 2, 3),           // 1.00
            (10000000000i64, 10, 11), // 1.0000000000
        ];

        for (unscaled, scale, min_precision) in test_cases {
            let bd = BigDecimal::new(BigInt::from(unscaled), scale as i64);
            // One digit short of the required precision must be rejected...
            assert!(Decimal::from_big_decimal(bd.clone(), min_precision - 1, scale).is_err());
            // ...while exactly the required precision is accepted.
            assert!(Decimal::from_big_decimal(bd, min_precision, scale).is_ok());
        }

        // i64::MAX has 19 digits and cannot fit precision 5.
        let bd = BigDecimal::new(BigInt::from(i64::MAX), 0);
        assert!(Decimal::from_big_decimal(bd, 5, 0).is_err());
    }

    #[test]
    fn test_from_unscaled_long_validation() {
        // Compact construction refuses precisions that cannot be backed by an i64.
        assert!(Decimal::from_unscaled_long(1, MAX_COMPACT_PRECISION + 1, 0).is_err());
        // 100 has three digits, so precision 2 is too small; 99 and -99 fit.
        assert!(Decimal::from_unscaled_long(100, 2, 0).is_err());
        assert!(Decimal::from_unscaled_long(99, 2, 0).is_ok());
        assert!(Decimal::from_unscaled_long(-99, 2, 0).is_ok());
        // i64::MIN has 19 digits, one more than the compact limit allows.
        assert!(Decimal::from_unscaled_long(i64::MIN, MAX_COMPACT_PRECISION, 0).is_err());
    }

    #[test]
    fn test_creation_and_representation() {
        let compact = Decimal::from_unscaled_long(12345, 10, 2).unwrap();
        assert_eq!(compact.precision(), 10);
        assert_eq!(compact.scale(), 2);
        assert!(compact.is_compact());
        assert_eq!(compact.to_unscaled_long().unwrap(), 12345);
        assert_eq!(compact.to_big_decimal().to_string(), "123.45");

        let bd = BigDecimal::new(BigInt::from(12345), 0);
        let non_compact = Decimal::from_big_decimal(bd, 28, 0).unwrap();
        assert_eq!(non_compact.precision(), 28);
        assert!(!non_compact.is_compact());
        assert_eq!(
            non_compact.to_unscaled_bytes(),
            BigInt::from(12345).to_signed_bytes_be()
        );

        assert!(Decimal::is_compact_precision(18));
        assert!(!Decimal::is_compact_precision(19));

        // 12.345 rounded HALF_UP to scale 2 is 12.35.
        let bd = BigDecimal::new(BigInt::from(12345), 3);
        let rounded = Decimal::from_big_decimal(bd, 10, 2).unwrap();
        assert_eq!(rounded.to_unscaled_long().unwrap(), 1235);
    }

    #[test]
    fn test_serialization_roundtrip() {
        // 13145678.90123: compact (precision 15).
        let bd1 = BigDecimal::new(BigInt::from(1314567890123i64), 5);
        let decimal1 = Decimal::from_big_decimal(bd1.clone(), 15, 5).unwrap();
        let (unscaled1, _) = bd1.as_bigint_and_exponent();
        let from_bytes1 =
            Decimal::from_unscaled_bytes(&unscaled1.to_signed_bytes_be(), 15, 5).unwrap();
        assert_eq!(from_bytes1, decimal1);
        assert_eq!(
            from_bytes1.to_unscaled_bytes(),
            unscaled1.to_signed_bytes_be()
        );

        // 1234567890.0987654321: non-compact (precision 23), unscaled exceeds i64.
        let bd2 = BigDecimal::new(BigInt::from(12345678900987654321i128), 10);
        let decimal2 = Decimal::from_big_decimal(bd2.clone(), 23, 10).unwrap();
        let (unscaled2, _) = bd2.as_bigint_and_exponent();
        let from_bytes2 =
            Decimal::from_unscaled_bytes(&unscaled2.to_signed_bytes_be(), 23, 10).unwrap();
        assert_eq!(from_bytes2, decimal2);
        assert_eq!(
            from_bytes2.to_unscaled_bytes(),
            unscaled2.to_signed_bytes_be()
        );
    }

    #[test]
    fn test_equality_and_ordering() {
        // 1.0 (scale 1) and 1.00 (scale 2) are numerically equal.
        let d1 = Decimal::from_big_decimal(BigDecimal::new(BigInt::from(10), 1), 2, 1).unwrap();
        let d2 = Decimal::from_big_decimal(BigDecimal::new(BigInt::from(100), 2), 3, 2).unwrap();
        assert_eq!(d1, d2, "Numeric equality: 1.0 == 1.00");
        assert_eq!(d1.cmp(&d2), std::cmp::Ordering::Equal);

        let small = Decimal::from_unscaled_long(10, 5, 0).unwrap();
        let large = Decimal::from_unscaled_long(15, 5, 0).unwrap();
        assert!(small < large);
        assert_eq!(small.cmp(&large), std::cmp::Ordering::Less);

        // -15 < -10: the more negative value orders first.
        let negative_large = Decimal::from_unscaled_long(-10, 5, 0).unwrap();
        let negative_small = Decimal::from_unscaled_long(-15, 5, 0).unwrap();
        assert!(negative_small < negative_large);
        assert_eq!(
            negative_small.cmp(&negative_large),
            std::cmp::Ordering::Less
        );

        let positive = Decimal::from_unscaled_long(5, 5, 0).unwrap();
        let negative = Decimal::from_unscaled_long(-5, 5, 0).unwrap();
        assert!(negative < positive);
        assert_eq!(negative.cmp(&positive), std::cmp::Ordering::Less);

        let original = Decimal::from_unscaled_long(10, 5, 0).unwrap();
        assert_eq!(original.clone(), original);
        assert_eq!(
            Decimal::from_unscaled_long(original.to_unscaled_long().unwrap(), 5, 0).unwrap(),
            original
        );
    }

    #[test]
    fn test_mixed_representation_ordering() {
        // Compact 1.5 vs non-compact 2 exercises the BigDecimal comparison path.
        let compact = Decimal::from_unscaled_long(15, 5, 1).unwrap();
        let wide = Decimal::from_big_decimal(BigDecimal::new(BigInt::from(2), 0), 20, 0).unwrap();
        assert!(compact < wide);
        assert_eq!(wide.cmp(&compact), std::cmp::Ordering::Greater);

        // The same value stored compactly and non-compactly compares equal.
        let wide_same =
            Decimal::from_big_decimal(BigDecimal::new(BigInt::from(15), 1), 20, 1).unwrap();
        assert_eq!(compact, wide_same);
        assert_eq!(compact.to_string(), "1.5");
    }

    #[test]
    fn test_hash_equals_contract() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        // 1.0 and 1.00 are equal, so they must hash identically.
        let d1 = Decimal::from_big_decimal(BigDecimal::new(BigInt::from(10), 1), 2, 1).unwrap();
        let d2 = Decimal::from_big_decimal(BigDecimal::new(BigInt::from(100), 2), 3, 2).unwrap();
        assert_eq!(d1, d2);

        let mut hasher1 = DefaultHasher::new();
        d1.hash(&mut hasher1);
        let hash1 = hasher1.finish();

        let mut hasher2 = DefaultHasher::new();
        d2.hash(&mut hasher2);
        let hash2 = hasher2.finish();

        assert_eq!(hash1, hash2, "Equal decimals must have equal hashes");

        // A key inserted as 1.0 must be found when looked up as 1.00.
        let mut map = std::collections::HashMap::new();
        map.insert(d1.clone(), "value");
        assert_eq!(map.get(&d2), Some(&"value"));
    }

    #[test]
    fn test_edge_cases() {
        let zero_compact = Decimal::from_unscaled_long(0, 5, 2).unwrap();
        assert_eq!(
            zero_compact.to_big_decimal(),
            BigDecimal::new(BigInt::from(0), 2)
        );

        let zero_non_compact =
            Decimal::from_big_decimal(BigDecimal::new(BigInt::from(0), 2), 20, 2).unwrap();
        assert_eq!(
            zero_non_compact.to_big_decimal(),
            BigDecimal::new(BigInt::from(0), 2)
        );

        // A 39-digit value exceeds both i64 and the compact range.
        let large_bd = BigDecimal::from_str("123456789012345678901234567890123456789").unwrap();
        let large = Decimal::from_big_decimal(large_bd, 39, 0).unwrap();
        let double_val = large.to_big_decimal().to_string().parse::<f64>().unwrap();
        // Compare with a relative tolerance: near 1.2e38 adjacent f64 values are about
        // 2e22 apart, so an absolute tolerance such as 0.01 silently demands bit-exact
        // equality.
        let expected = 1.2345678901234568E38_f64;
        assert!(((double_val - expected) / expected).abs() < 1e-15);

        // Rescaling 5.0 to scale 2 pads with a trailing zero.
        let d1 = Decimal::from_big_decimal(BigDecimal::new(BigInt::from(50), 1), 10, 1).unwrap();
        let d2 = Decimal::from_big_decimal(d1.to_big_decimal(), 10, 2).unwrap();
        assert_eq!(d2.to_big_decimal().to_string(), "5.00");
        assert_eq!(d2.scale(), 2);
    }
}