Skip to main content

reifydb_type/value/decimal/
mod.rs

1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025 ReifyDB
3
4use std::{
5	cmp::Ordering,
6	fmt::{Display, Formatter},
7	ops::{Add, Deref, Div, Mul, Sub},
8	str::FromStr,
9};
10
11use bigdecimal::{BigDecimal as BigDecimalInner, FromPrimitive};
12use num_traits::{One, Zero};
13use serde::{
14	Deserialize, Deserializer, Serialize, Serializer,
15	de::{self, Visitor},
16};
17
18use super::{int::Int, uint::Uint};
19use crate::{error, error::Error, fragment::Fragment, value::r#type::Type};
20
21pub mod parse;
22
23use crate::error::diagnostic::number::invalid_number_format;
24
25#[repr(transparent)]
26#[derive(Clone, Debug)]
27pub struct Decimal(pub BigDecimalInner);
28
29impl Decimal {
30	pub fn zero() -> Self {
31		Self(BigDecimalInner::zero())
32	}
33
34	pub fn one() -> Self {
35		Self(BigDecimalInner::one())
36	}
37}
38
39impl Deref for Decimal {
40	type Target = BigDecimalInner;
41
42	fn deref(&self) -> &Self::Target {
43		&self.0
44	}
45}
46
47impl Decimal {
48	pub fn new(value: BigDecimalInner) -> Self {
49		Self(value)
50	}
51
52	pub fn from_bigdecimal(value: BigDecimalInner) -> Self {
53		Self(value)
54	}
55
56	pub fn with_scale(value: BigDecimalInner, scale: i64) -> Self {
57		Self(value.with_scale(scale))
58	}
59
60	pub fn from_i64(value: i64) -> Self {
61		Self(BigDecimalInner::from(value))
62	}
63
64	pub fn inner(&self) -> &BigDecimalInner {
65		&self.0
66	}
67
68	pub fn to_bigdecimal(self) -> BigDecimalInner {
69		self.0
70	}
71
72	pub fn negate(self) -> Self {
73		Self(-self.0)
74	}
75}
76
77impl PartialEq for Decimal {
78	fn eq(&self, other: &Self) -> bool {
79		self.0 == other.0
80	}
81}
82
83impl Eq for Decimal {}
84
85impl PartialOrd for Decimal {
86	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
87		Some(self.cmp(other))
88	}
89}
90
91impl Ord for Decimal {
92	fn cmp(&self, other: &Self) -> Ordering {
93		self.0.cmp(&other.0)
94	}
95}
96
97impl Display for Decimal {
98	fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
99		self.0.fmt(f)
100	}
101}
102
103impl std::hash::Hash for Decimal {
104	fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
105		self.0.to_string().hash(state);
106	}
107}
108
109impl FromStr for Decimal {
110	type Err = Error;
111
112	fn from_str(s: &str) -> Result<Self, Self::Err> {
113		let big_decimal = BigDecimalInner::from_str(s)
114			.map_err(|_| error!(invalid_number_format(Fragment::None, Type::Decimal)))?;
115
116		Ok(Self(big_decimal))
117	}
118}
119
120impl From<i64> for Decimal {
121	fn from(value: i64) -> Self {
122		Self(BigDecimalInner::from(value))
123	}
124}
125
126impl From<i8> for Decimal {
127	fn from(value: i8) -> Self {
128		Self::from(value as i64)
129	}
130}
131
132impl From<i16> for Decimal {
133	fn from(value: i16) -> Self {
134		Self::from(value as i64)
135	}
136}
137
138impl From<i32> for Decimal {
139	fn from(value: i32) -> Self {
140		Self::from(value as i64)
141	}
142}
143
144impl From<i128> for Decimal {
145	fn from(value: i128) -> Self {
146		Self(BigDecimalInner::from(value))
147	}
148}
149
150impl From<u8> for Decimal {
151	fn from(value: u8) -> Self {
152		Self::from(value as i64)
153	}
154}
155
156impl From<u16> for Decimal {
157	fn from(value: u16) -> Self {
158		Self::from(value as i64)
159	}
160}
161
162impl From<u32> for Decimal {
163	fn from(value: u32) -> Self {
164		Self::from(value as i64)
165	}
166}
167
168impl From<u64> for Decimal {
169	fn from(value: u64) -> Self {
170		Self(BigDecimalInner::from(value))
171	}
172}
173
174impl From<u128> for Decimal {
175	fn from(value: u128) -> Self {
176		Self(BigDecimalInner::from(value))
177	}
178}
179
180impl From<f32> for Decimal {
181	fn from(value: f32) -> Self {
182		let inner = BigDecimalInner::from_f32(value).unwrap_or_else(|| BigDecimalInner::from(0));
183		Self(inner)
184	}
185}
186
187impl From<f64> for Decimal {
188	fn from(value: f64) -> Self {
189		let inner = BigDecimalInner::from_f64(value).unwrap_or_else(|| BigDecimalInner::from(0));
190		Self(inner)
191	}
192}
193
194impl From<BigDecimalInner> for Decimal {
195	fn from(value: BigDecimalInner) -> Self {
196		Self(value)
197	}
198}
199
200impl From<Int> for Decimal {
201	fn from(value: Int) -> Self {
202		Self(BigDecimalInner::from_bigint(value.0, 0))
203	}
204}
205
206impl From<Uint> for Decimal {
207	fn from(value: Uint) -> Self {
208		Self(BigDecimalInner::from_bigint(value.0, 0))
209	}
210}
211
212// Arithmetic operations
213impl Add for Decimal {
214	type Output = Self;
215
216	fn add(self, rhs: Self) -> Self::Output {
217		Self(self.0 + rhs.0)
218	}
219}
220
221impl Sub for Decimal {
222	type Output = Self;
223
224	fn sub(self, rhs: Self) -> Self::Output {
225		Self(self.0 - rhs.0)
226	}
227}
228
229impl Mul for Decimal {
230	type Output = Self;
231
232	fn mul(self, rhs: Self) -> Self::Output {
233		Self(self.0 * rhs.0)
234	}
235}
236
237impl Div for Decimal {
238	type Output = Self;
239
240	fn div(self, rhs: Self) -> Self::Output {
241		Self(self.0 / rhs.0)
242	}
243}
244
245// Reference arithmetic operations (to avoid cloning)
246impl Add<&Decimal> for &Decimal {
247	type Output = Decimal;
248
249	fn add(self, rhs: &Decimal) -> Self::Output {
250		Decimal(&self.0 + &rhs.0)
251	}
252}
253
254impl Sub<&Decimal> for &Decimal {
255	type Output = Decimal;
256
257	fn sub(self, rhs: &Decimal) -> Self::Output {
258		Decimal(&self.0 - &rhs.0)
259	}
260}
261
262impl Mul<&Decimal> for &Decimal {
263	type Output = Decimal;
264
265	fn mul(self, rhs: &Decimal) -> Self::Output {
266		Decimal(&self.0 * &rhs.0)
267	}
268}
269
270impl Div<&Decimal> for &Decimal {
271	type Output = Decimal;
272
273	fn div(self, rhs: &Decimal) -> Self::Output {
274		Decimal(&self.0 / &rhs.0)
275	}
276}
277
278impl Default for Decimal {
279	fn default() -> Self {
280		Self::zero()
281	}
282}
283
284// Serde implementation for string-based serialization
285// This works with both JSON and binary formats (bincode, rmp, etc.)
286impl Serialize for Decimal {
287	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
288	where
289		S: Serializer,
290	{
291		serializer.serialize_str(&self.0.to_string())
292	}
293}
294
295struct DecimalVisitor;
296
297impl<'de> Visitor<'de> for DecimalVisitor {
298	type Value = Decimal;
299
300	fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
301		formatter.write_str("a decimal number as a string")
302	}
303
304	fn visit_str<E>(self, value: &str) -> Result<Decimal, E>
305	where
306		E: de::Error,
307	{
308		BigDecimalInner::from_str(value).map(Decimal).map_err(|e| E::custom(format!("invalid decimal: {}", e)))
309	}
310}
311
312impl<'de> Deserialize<'de> for Decimal {
313	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
314	where
315		D: Deserializer<'de>,
316	{
317		deserializer.deserialize_str(DecimalVisitor)
318	}
319}
320
#[cfg(test)]
pub mod tests {
	use super::*;

	/// Parses a literal that the test expects to be valid.
	fn dec(text: &str) -> Decimal {
		Decimal::from_str(text).unwrap()
	}

	/// Encodes `decimal` to JSON and decodes it back, returning both forms.
	fn json_roundtrip(decimal: &Decimal) -> (String, Decimal) {
		let json = serde_json::to_string(decimal).unwrap();
		let back: Decimal = serde_json::from_str(&json).unwrap();
		(json, back)
	}

	/// Encodes `decimal` with postcard and decodes it back.
	fn postcard_roundtrip(decimal: &Decimal) -> Decimal {
		let bytes = postcard::to_stdvec(decimal).unwrap();
		postcard::from_bytes(&bytes).unwrap()
	}

	#[test]
	fn test_new_decimal_valid() {
		let decimal = Decimal::new(BigDecimalInner::from_str("123.45").unwrap());
		assert_eq!(decimal.to_string(), "123.45");
	}

	#[test]
	fn test_from_str() {
		assert_eq!(dec("123.45").to_string(), "123.45");
	}

	#[test]
	fn test_comparison() {
		let (low, high, low_again) = (dec("123.45"), dec("123.46"), dec("123.45"));

		assert!(low < high);
		assert_eq!(low, low_again);
	}

	#[test]
	fn test_display() {
		assert_eq!(format!("{}", dec("123.45")), "123.45");
	}

	#[test]
	fn test_serde_json() {
		let decimal = dec("123.456789");
		let (json, back) = json_roundtrip(&decimal);
		assert_eq!(json, "\"123.456789\"");
		assert_eq!(back, decimal);
	}

	#[test]
	fn test_serde_json_negative() {
		let decimal = dec("-987.654321");
		let (json, back) = json_roundtrip(&decimal);
		assert_eq!(json, "\"-987.654321\"");
		assert_eq!(back, decimal);
	}

	#[test]
	fn test_serde_json_zero() {
		let decimal = Decimal::zero();
		let (json, back) = json_roundtrip(&decimal);
		assert_eq!(json, "\"0\"");
		assert_eq!(back, decimal);
	}

	#[test]
	fn test_serde_json_high_precision() {
		let decimal = dec("123456789.123456789123456789");
		let (_json, back) = json_roundtrip(&decimal);
		assert_eq!(back, decimal);
	}

	#[test]
	fn test_serde_postcard() {
		let decimal = dec("123.456789");
		assert_eq!(postcard_roundtrip(&decimal), decimal);
	}

	#[test]
	fn test_serde_postcard_negative() {
		let decimal = dec("-987.654321");
		assert_eq!(postcard_roundtrip(&decimal), decimal);
	}

	#[test]
	fn test_serde_postcard_zero() {
		let decimal = Decimal::zero();
		assert_eq!(postcard_roundtrip(&decimal), decimal);
	}

	#[test]
	fn test_serde_postcard_high_precision() {
		let decimal = dec("123456789.123456789123456789");
		assert_eq!(postcard_roundtrip(&decimal), decimal);
	}

	#[test]
	fn test_serde_postcard_large_number() {
		let decimal = dec("999999999999999999999999999999.999999999999999999999999");
		assert_eq!(postcard_roundtrip(&decimal), decimal);
	}
}