1use crate::errors::MathError;
2use std::{
3 fmt,
4 ops::{Add, Div, Mul, Sub},
5 str::FromStr,
6};
7
8use cosmwasm_std as cw;
9
/// Sign component of a [`DecimalExt`].
///
/// The default is [`Sign::Zero`]. The derived `PartialOrd`/`Ord` follow the
/// declaration order below (`Positive < Negative < Zero`), which is NOT numeric
/// order — do not reorder the variants without auditing every comparison.
#[derive(Clone, Copy, Debug, Default, Eq, Ord, PartialEq, PartialOrd)]
pub enum Sign {
    /// Strictly greater than zero.
    Positive,
    /// Strictly less than zero.
    Negative,
    /// Exactly zero; the default.
    #[default]
    Zero,
}
23
24#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Debug)]
27pub struct DecimalExt {
28 sign: Sign,
29 dec: cw::Decimal,
30}
31
32impl DecimalExt {
33 pub fn zero() -> Self {
34 DecimalExt::default()
35 }
36
37 pub fn sign(&self) -> Sign {
39 self.sign
40 }
41
42 pub fn abc_cw_dec(&self) -> cw::Decimal {
44 self.dec
45 }
46
47 pub fn add(&self, other: Self) -> Self {
48 if self.sign == other.sign {
49 return DecimalExt {
50 sign: self.sign,
51 dec: self.dec.add(other.dec),
52 };
53 } else if other.dec.is_zero() {
54 return *self;
55 }
56
57 let self_dec_gt: bool = self.dec.ge(&other.dec);
58 let sign = if self_dec_gt { self.sign } else { other.sign };
59 let dec = if self_dec_gt {
60 self.dec.sub(other.dec) } else {
62 other.dec.sub(self.dec) };
64 let sign = if dec.is_zero() { Sign::Zero } else { sign };
65
66 DecimalExt { sign, dec }
67 }
68
69 pub fn neg(&self) -> Self {
70 match self.sign {
71 Sign::Positive => DecimalExt {
72 sign: Sign::Negative,
73 dec: self.dec,
74 },
75 Sign::Negative => DecimalExt {
76 sign: Sign::Positive,
77 dec: self.dec,
78 },
79 Sign::Zero => *self,
80 }
81 }
82
83 pub fn sub(&self, other: Self) -> Self {
84 self.add(other.neg())
85 }
86
87 pub fn mul(&self, other: Self) -> Self {
88 let dec = self.dec.mul(other.dec);
89 let sign = match (self.sign, other.sign) {
90 (Sign::Zero, _) | (_, Sign::Zero) => Sign::Zero,
91 (Sign::Positive, Sign::Positive)
92 | (Sign::Negative, Sign::Negative) => Sign::Positive,
93 (Sign::Positive, Sign::Negative)
94 | (Sign::Negative, Sign::Positive) => Sign::Negative,
95 };
96 DecimalExt { sign, dec }
97 }
98
99 pub fn quo(&self, other: Self) -> Result<Self, MathError> {
100 let sign = match (self.sign, other.sign) {
101 (Sign::Zero, _) => Sign::Zero,
102 (_, Sign::Zero) => return Err(MathError::DivisionByZero),
103 (Sign::Positive, Sign::Positive)
104 | (Sign::Negative, Sign::Negative) => Sign::Positive,
105 (Sign::Positive, Sign::Negative)
106 | (Sign::Negative, Sign::Positive) => Sign::Negative,
107 };
108 let dec = self.dec.div(other.dec);
109 Ok(DecimalExt { sign, dec })
110 }
111}
112
113impl From<cw::Decimal> for DecimalExt {
114 fn from(cw_dec: cw::Decimal) -> Self {
115 if cw_dec.is_zero() {
116 return DecimalExt::zero();
117 }
118 DecimalExt {
119 sign: Sign::Positive,
120 dec: cw_dec,
121 }
122 }
123}
124
125impl FromStr for DecimalExt {
126 type Err = MathError;
127
128 fn from_str(s: &str) -> Result<Self, Self::Err> {
132 let non_strict_sign = if s.starts_with('-') {
133 Sign::Negative
134 } else {
135 Sign::Positive
136 };
137
138 let abs_value = if let Some(s) = s.strip_prefix('-') {
139 s } else {
141 s
142 };
143
144 let cw_dec: cw::Decimal =
145 cw::Decimal::from_str(abs_value).map_err(|cw_std_err| {
146 MathError::CwDecParseError {
147 dec_str: s.to_string(),
148 err: cw_std_err,
149 }
150 })?;
151 let sign = if cw_dec.is_zero() {
152 Sign::Zero
153 } else {
154 non_strict_sign
155 };
156 Ok(DecimalExt { sign, dec: cw_dec })
157 }
158}
159
160impl fmt::Display for DecimalExt {
161 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162 let prefix = if self.sign == Sign::Negative { "-" } else { "" };
163 write!(f, "{}{}", prefix, self.dec)
164 }
165}
166
/// A Cosmos-SDK `Dec` held in its protobuf string encoding.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so values can be logged, copied
/// and compared; the private field is only set via `SdkDec::new`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SdkDec {
    /// Output of `DecimalExt::to_sdk_dec_pb_repr`: the magnitude's integer
    /// digits followed by its fractional digits padded to 18 places, with no
    /// decimal point and a leading `-` when negative.
    protobuf_repr: String,
}
173
174impl SdkDec {
175 pub fn new(dec: &DecimalExt) -> Result<Self, MathError> {
176 Ok(Self {
177 protobuf_repr: dec.to_sdk_dec_pb_repr()?,
178 })
179 }
180
181 pub fn pb_repr(&self) -> String {
183 self.protobuf_repr.to_string()
184 }
185
186 pub fn from_dec(dec: DecimalExt) -> Result<Self, MathError> {
187 Self::new(&dec)
188 }
189
190 pub fn from_cw_dec(cw_dec: cw::Decimal) -> Result<Self, MathError> {
191 Self::new(&DecimalExt::from(cw_dec))
192 }
193}
194
195impl FromStr for SdkDec {
196 type Err = MathError;
197
198 fn from_str(s: &str) -> Result<Self, Self::Err> {
205 Self::new(&DecimalExt::from_str(s)?)
206 }
207}
208
209impl fmt::Display for SdkDec {
210 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211 let dec =
212 DecimalExt::from_sdk_dec(&self.pb_repr()).unwrap_or_else(|err| {
213 panic!(
214 "ParseError: could not marshal SdkDec {} to DecimalExt: {}",
215 self.protobuf_repr, err,
216 )
217 });
218 write!(f, "{}", dec)
219 }
220}
221
222impl DecimalExt {
223 pub fn precision_digits() -> usize {
224 18
225 }
226
227 pub fn to_sdk_dec(&self) -> Result<SdkDec, MathError> {
230 SdkDec::new(self)
231 }
232
233 pub fn to_sdk_dec_pb_repr(&self) -> Result<String, MathError> {
236 if self.dec.is_zero() {
237 return Ok("0".repeat(DecimalExt::precision_digits()));
238 }
239
240 let abs_str = self.dec.to_string();
242
243 let neg = self.sign == Sign::Negative;
245
246 let parts: Vec<&str> = abs_str.split('.').collect();
248 let (int_part, frac_part) = match parts.as_slice() {
249 [int_part, frac_part] => (*int_part, *frac_part),
250 [int_part] => (*int_part, ""),
251 _ => {
252 return Err(MathError::SdkDecError(format!(
253 "Invalid decimal format: {}",
254 abs_str
255 )))
256 }
257 };
258
259 if int_part.is_empty() || (parts.len() == 2 && frac_part.is_empty()) {
261 return Err(MathError::SdkDecError(format!(
262 "Expected decimal string but got: {}",
263 abs_str
264 )));
265 }
266
267 let mut sdk_dec = format!("{int_part}{frac_part}");
270
271 let precision_digits = DecimalExt::precision_digits();
273 if frac_part.len() > precision_digits {
274 return Err(MathError::SdkDecError(format!(
275 "Value exceeds max precision digits ({}): {}",
276 precision_digits, abs_str
277 )));
278 }
279 for _ in 0..(precision_digits - frac_part.len()) {
280 sdk_dec.push('0');
281 }
282
283 if neg {
285 sdk_dec.insert(0, '-');
286 }
287
288 Ok(sdk_dec)
289 }
290
291 pub fn from_sdk_dec(sdk_dec_str: &str) -> Result<DecimalExt, MathError> {
292 let precision_digits = DecimalExt::precision_digits();
293 if sdk_dec_str.is_empty() {
294 return Ok(DecimalExt::zero());
295 }
296
297 if sdk_dec_str.contains('.') {
298 return Err(MathError::SdkDecError(format!(
299 "Expected a decimal string but got '{}'",
300 sdk_dec_str
301 )));
302 }
303
304 let (neg, abs_str) =
306 if let Some(stripped) = sdk_dec_str.strip_prefix('-') {
307 (true, stripped)
308 } else {
309 (false, sdk_dec_str)
310 };
311
312 if abs_str.is_empty() || abs_str.chars().any(|c| !c.is_ascii_digit()) {
313 return Err(MathError::SdkDecError(format!(
314 "Invalid decimal format: {}",
315 sdk_dec_str
316 )));
317 }
318
319 let input_size = abs_str.len();
320 let mut decimal_str = String::new();
321
322 if input_size <= precision_digits {
323 decimal_str.push_str("0.");
325 decimal_str.push_str(&"0".repeat(precision_digits - input_size));
326 decimal_str.push_str(abs_str);
327 } else {
328 let dec_point_place = input_size - precision_digits;
330 decimal_str.push_str(&abs_str[..dec_point_place]);
331 decimal_str.push('.');
332 decimal_str.push_str(&abs_str[dec_point_place..]);
333 }
334
335 if neg {
336 decimal_str.insert(0, '-');
337 }
338
339 DecimalExt::from_str(&decimal_str).map_err(Into::into)
340 }
341}
342
#[cfg(test)]
mod test_sign_dec {
    use cosmwasm_std as cw;
    use std::str::FromStr;

    use crate::{
        errors::{MathError, TestResult},
        math::{DecimalExt, SdkDec, Sign},
    };

    #[test]
    fn default_is_zero() -> TestResult {
        assert_eq!(
            DecimalExt::default(),
            DecimalExt {
                sign: Sign::Zero,
                dec: cw::Decimal::from_str("0")?
            }
        );
        assert_eq!(DecimalExt::default(), DecimalExt::zero());
        assert_eq!(DecimalExt::zero(), cw::Decimal::from_str("0")?.into());
        Ok(())
    }

    #[test]
    fn from_cw() -> TestResult {
        assert_eq!(
            DecimalExt::default(),
            DecimalExt::from(cw::Decimal::from_str("0")?)
        );

        let cw_dec = cw::Decimal::from_str("123.456")?;
        assert_eq!(
            DecimalExt {
                sign: Sign::Positive,
                dec: cw_dec
            },
            DecimalExt::from(cw_dec)
        );

        let num = "123.456";
        assert_eq!(
            DecimalExt {
                sign: Sign::Negative,
                dec: cw::Decimal::from_str(num)?
            },
            DecimalExt::from_str(&format!("-{}", num))?
        );

        Ok(())
    }

    #[test]
    fn add() -> TestResult {
        let test_cases: &[(&str, &str, &str)] = &[
            ("0", "0", "0"),
            ("0", "420", "420"),
            ("69", "420", "489"),
            ("5", "-3", "2"),
            ("-7", "7", "0"),
            ("-420", "69", "-351"),
            ("-69", "420", "351"),
        ];
        for &(a, b, want_sum_of) in test_cases.iter() {
            let a = DecimalExt::from_str(a)?;
            let b = DecimalExt::from_str(b)?;
            let want_sum_of = DecimalExt::from_str(want_sum_of)?;
            let got_sum_of = a.add(b);
            assert_eq!(want_sum_of, got_sum_of);
        }
        Ok(())
    }

    /// `sub` is `add` of the negation; covers both operand orders and exact
    /// cancellation to zero.
    #[test]
    fn sub() -> TestResult {
        let test_cases: &[(&str, &str, &str)] = &[
            ("0", "0", "0"),
            ("5", "3", "2"),
            ("3", "5", "-2"),
            ("0", "7", "-7"),
            ("-3", "-3", "0"),
            ("-3", "4", "-7"),
        ];
        for &(a, b, want_diff) in test_cases.iter() {
            let a = DecimalExt::from_str(a)?;
            let b = DecimalExt::from_str(b)?;
            let want_diff = DecimalExt::from_str(want_diff)?;
            assert_eq!(want_diff, a.sub(b));
        }
        Ok(())
    }

    #[test]
    fn neg() -> TestResult {
        let pos_num = DecimalExt::from_str("69")?;
        let neg_num = DecimalExt::from_str("-69")?;
        let zero_num = DecimalExt::zero();

        assert_eq!(neg_num, pos_num.neg());
        assert_eq!(pos_num, neg_num.neg());
        assert_eq!(zero_num, zero_num.neg());
        Ok(())
    }

    #[test]
    fn mul() -> TestResult {
        let test_cases: &[(&str, &str, &str)] = &[
            ("0", "0", "0"),
            ("0", "420", "0"),
            ("16", "16", "256"),
            ("5", "-3", "-15"),
            ("-7", "7", "-49"),
        ];
        for &(a, b, want_product) in test_cases.iter() {
            let a = DecimalExt::from_str(a)?;
            let b = DecimalExt::from_str(b)?;
            let want_product = DecimalExt::from_str(want_product)?;
            let got_product = a.mul(b);
            assert_eq!(want_product, got_product);
        }
        Ok(())
    }

    #[test]
    fn quo() -> TestResult {
        let test_cases: &[(&str, &str, &str)] = &[
            ("0", "420", "0"),
            ("256", "16", "16"),
            ("-15", "5", "-3"),
            ("-49", "-7", "7"),
        ];
        for &(a, b, want_quo) in test_cases.iter() {
            let a = DecimalExt::from_str(a)?;
            let b = DecimalExt::from_str(b)?;
            let want_quo = DecimalExt::from_str(want_quo)?;
            let got_quo = a.quo(b)?;
            assert_eq!(want_quo, got_quo);
        }
        Ok(())
    }

    /// Dividing a nonzero value by zero must be reported as an error, not panic.
    #[test]
    fn quo_by_zero() -> TestResult {
        for num in ["5", "-5"] {
            let got = DecimalExt::from_str(num)?.quo(DecimalExt::zero());
            assert!(matches!(got, Err(MathError::DivisionByZero)));
        }
        Ok(())
    }

    #[test]
    fn sdk_dec_int_only() -> TestResult {
        let test_cases: &[(&str, &str)] = &[
            ("0", &"0".repeat(18)),
            ("000.00", &"0".repeat(18)),
            ("0.00", &"0".repeat(18)),
            ("00000", &"0".repeat(18)),
            ("10", &format!("10{}", "0".repeat(18))),
            ("-10", &format!("-10{}", "0".repeat(18))),
            ("123", &format!("123{}", "0".repeat(18))),
            ("-123", &format!("-123{}", "0".repeat(18))),
        ];

        for tc in test_cases.iter() {
            let (arg, want_sdk_dec) = tc;
            let want_dec: DecimalExt = DecimalExt::from_str(arg)?;
            let got_sdk_dec: String = want_dec.to_sdk_dec_pb_repr()?;
            assert_eq!(want_sdk_dec.to_owned(), got_sdk_dec);

            let got_dec = DecimalExt::from_sdk_dec(&got_sdk_dec)?;
            assert_eq!(want_dec, got_dec)
        }
        Ok(())
    }

    #[test]
    fn sdk_dec_fractional() -> TestResult {
        let test_cases: &[(&str, &str)] = &[
            ("0.5", &format!("05{}", "0".repeat(17))),
            ("0.005", &format!("0005{}", "0".repeat(15))),
            ("123.456", &format!("123456{}", "0".repeat(15))),
            ("-123.456", &format!("-123456{}", "0".repeat(15))),
            ("0.00596", &format!("000596{}", "0".repeat(13))),
            ("13.5", &format!("135{}", "0".repeat(17))),
            ("-13.5", &format!("-135{}", "0".repeat(17))),
            ("1574.00005", &format!("157400005{}", "0".repeat(13))),
        ];

        for tc in test_cases.iter() {
            let (arg, want_sdk_dec) = tc;
            let want_dec: DecimalExt = DecimalExt::from_str(arg)?;
            let got_sdk_dec: String = want_dec.to_sdk_dec_pb_repr()?;
            assert_eq!(want_sdk_dec.to_owned(), got_sdk_dec);

            let got_dec = DecimalExt::from_sdk_dec(&got_sdk_dec)?;
            assert_eq!(want_dec, got_dec)
        }
        Ok(())
    }

    /// `from_sdk_dec` treats "" as zero and rejects anything that is not an
    /// optional single '-' followed by ASCII digits.
    #[test]
    fn from_sdk_dec_rejects_malformed() -> TestResult {
        assert_eq!(DecimalExt::from_sdk_dec("")?, DecimalExt::zero());
        for bad in ["1.5", "-", "12a", "+100", "--1"] {
            assert!(
                DecimalExt::from_sdk_dec(bad).is_err(),
                "expected error for {:?}",
                bad
            );
        }
        Ok(())
    }

    /// `Display` of both `DecimalExt` and `SdkDec` reproduces the canonical
    /// input string.
    #[test]
    fn display_round_trip() -> TestResult {
        for s in ["0", "123.456", "-123.456", "0.005"] {
            assert_eq!(DecimalExt::from_str(s)?.to_string(), s);
            assert_eq!(SdkDec::from_str(s)?.to_string(), s);
        }
        Ok(())
    }
}