nibiru_std/
math.rs

1use crate::errors::MathError;
2use std::{
3    fmt,
4    ops::{Add, Div, Mul, Sub},
5    str::FromStr,
6};
7
8use cosmwasm_std as cw;
9
10// cosmwasm dec from sdk dec
11// TODO: cosmwasm dec from sdk int  -> What's the max value
12// Decimal
13
/// Sign: The sign of a number. "Positive" and "Negative" mean strictly
/// positive and negative (excluding 0), respectively.
///
/// Ordering follows the number line: `Negative < Zero < Positive`. It is
/// implemented by hand because a derived `Ord` would follow the variant
/// declaration order (`Positive < Negative < Zero`). The declaration order
/// itself is left untouched so discriminant values stay the same.
#[derive(Debug, Default, Copy, Clone, PartialEq, Eq)]
pub enum Sign {
    Positive,
    Negative,
    #[default]
    Zero,
}

impl Ord for Sign {
    /// Compares two signs by their position on the number line.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        /// Maps a sign to -1, 0 or 1.
        fn rank(sign: Sign) -> i8 {
            match sign {
                Sign::Negative => -1,
                Sign::Zero => 0,
                Sign::Positive => 1,
            }
        }
        rank(*self).cmp(&rank(*other))
    }
}

impl PartialOrd for Sign {
    /// Total order; always returns `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
23
24/// DecimalExt: Implements a signed version of `cosmwasm_std::Decimal`
25/// with extentions for generating protobuf type strings.
26#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Debug)]
27pub struct DecimalExt {
28    sign: Sign,
29    dec: cw::Decimal,
30}
31
32impl DecimalExt {
33    pub fn zero() -> Self {
34        DecimalExt::default()
35    }
36
37    /// Getter for `Sign`, which can be +, -, 0.
38    pub fn sign(&self) -> Sign {
39        self.sign
40    }
41
42    /// Getter for the underlying `cosmwasm_std::Decimal`.
43    pub fn abc_cw_dec(&self) -> cw::Decimal {
44        self.dec
45    }
46
47    pub fn add(&self, other: Self) -> Self {
48        if self.sign == other.sign {
49            return DecimalExt {
50                sign: self.sign,
51                dec: self.dec.add(other.dec),
52            };
53        } else if other.dec.is_zero() {
54            return *self;
55        }
56
57        let self_dec_gt: bool = self.dec.ge(&other.dec);
58        let sign = if self_dec_gt { self.sign } else { other.sign };
59        let dec = if self_dec_gt {
60            self.dec.sub(other.dec) // if abs(self.dec) > abs(other.dec)
61        } else {
62            other.dec.sub(self.dec) // if abs(self.dec) < abs(other.dec)
63        };
64        let sign = if dec.is_zero() { Sign::Zero } else { sign };
65
66        DecimalExt { sign, dec }
67    }
68
69    pub fn neg(&self) -> Self {
70        match self.sign {
71            Sign::Positive => DecimalExt {
72                sign: Sign::Negative,
73                dec: self.dec,
74            },
75            Sign::Negative => DecimalExt {
76                sign: Sign::Positive,
77                dec: self.dec,
78            },
79            Sign::Zero => *self,
80        }
81    }
82
83    pub fn sub(&self, other: Self) -> Self {
84        self.add(other.neg())
85    }
86
87    pub fn mul(&self, other: Self) -> Self {
88        let dec = self.dec.mul(other.dec);
89        let sign = match (self.sign, other.sign) {
90            (Sign::Zero, _) | (_, Sign::Zero) => Sign::Zero,
91            (Sign::Positive, Sign::Positive)
92            | (Sign::Negative, Sign::Negative) => Sign::Positive,
93            (Sign::Positive, Sign::Negative)
94            | (Sign::Negative, Sign::Positive) => Sign::Negative,
95        };
96        DecimalExt { sign, dec }
97    }
98
99    pub fn quo(&self, other: Self) -> Result<Self, MathError> {
100        let sign = match (self.sign, other.sign) {
101            (Sign::Zero, _) => Sign::Zero,
102            (_, Sign::Zero) => return Err(MathError::DivisionByZero),
103            (Sign::Positive, Sign::Positive)
104            | (Sign::Negative, Sign::Negative) => Sign::Positive,
105            (Sign::Positive, Sign::Negative)
106            | (Sign::Negative, Sign::Positive) => Sign::Negative,
107        };
108        let dec = self.dec.div(other.dec);
109        Ok(DecimalExt { sign, dec })
110    }
111}
112
113impl From<cw::Decimal> for DecimalExt {
114    fn from(cw_dec: cw::Decimal) -> Self {
115        if cw_dec.is_zero() {
116            return DecimalExt::zero();
117        }
118        DecimalExt {
119            sign: Sign::Positive,
120            dec: cw_dec,
121        }
122    }
123}
124
125impl FromStr for DecimalExt {
126    type Err = MathError;
127
128    /// Converts the decimal string to a `DecimalExt`
129    /// Possible inputs: "-69", "-420.69", "1.23", "1", "0012", "1.123000",
130    /// Disallowed: "", ".23"
131    fn from_str(s: &str) -> Result<Self, Self::Err> {
132        let non_strict_sign = if s.starts_with('-') {
133            Sign::Negative
134        } else {
135            Sign::Positive
136        };
137
138        let abs_value = if let Some(s) = s.strip_prefix('-') {
139            s // Strip the negative sign for parsing
140        } else {
141            s
142        };
143
144        let cw_dec: cw::Decimal =
145            cw::Decimal::from_str(abs_value).map_err(|cw_std_err| {
146                MathError::CwDecParseError {
147                    dec_str: s.to_string(),
148                    err: cw_std_err,
149                }
150            })?;
151        let sign = if cw_dec.is_zero() {
152            Sign::Zero
153        } else {
154            non_strict_sign
155        };
156        Ok(DecimalExt { sign, dec: cw_dec })
157    }
158}
159
160impl fmt::Display for DecimalExt {
161    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162        let prefix = if self.sign == Sign::Negative { "-" } else { "" };
163        write!(f, "{}{}", prefix, self.dec)
164    }
165}
166
/// SdkDec: Decimal string representing the protobuf string for
/// `"cosmossdk.io/math".LegacyDec`.
/// See https://pkg.go.dev/cosmossdk.io/math@v1.2.0#LegacyDec.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so values can be logged,
/// copied and compared (e.g. in `assert_eq!`). Equality compares the stored
/// protobuf strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SdkDec {
    /// The encoded protobuf string.
    protobuf_repr: String,
}
173
174impl SdkDec {
175    pub fn new(dec: &DecimalExt) -> Result<Self, MathError> {
176        Ok(Self {
177            protobuf_repr: dec.to_sdk_dec_pb_repr()?,
178        })
179    }
180
181    /// Returns the protobuf representation.
182    pub fn pb_repr(&self) -> String {
183        self.protobuf_repr.to_string()
184    }
185
186    pub fn from_dec(dec: DecimalExt) -> Result<Self, MathError> {
187        Self::new(&dec)
188    }
189
190    pub fn from_cw_dec(cw_dec: cw::Decimal) -> Result<Self, MathError> {
191        Self::new(&DecimalExt::from(cw_dec))
192    }
193}
194
195impl FromStr for SdkDec {
196    type Err = MathError;
197
198    /// Converts the decimal string to an `SdkDec` compatible for use with
199    /// protobuf strings corresponding to `"cosmossdk.io/math".LegacyDec`
200    /// See https://pkg.go.dev/cosmossdk.io/math@v1.2.0#LegacyDec.
201    ///
202    /// Possible inputs: "-69", "-420.69", "1.23", "1", "0012", "1.123000",
203    /// Disallowed: "", ".23"
204    fn from_str(s: &str) -> Result<Self, Self::Err> {
205        Self::new(&DecimalExt::from_str(s)?)
206    }
207}
208
209impl fmt::Display for SdkDec {
210    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211        let dec =
212            DecimalExt::from_sdk_dec(&self.pb_repr()).unwrap_or_else(|err| {
213                panic!(
214                    "ParseError: could not marshal SdkDec {} to DecimalExt: {}",
215                    self.protobuf_repr, err,
216                )
217            });
218        write!(f, "{}", dec)
219    }
220}
221
222impl DecimalExt {
223    pub fn precision_digits() -> usize {
224        18
225    }
226
227    /// to_sdk_dec_pb_repr: Encodes the `DecimalExt` from the human readable
228    /// form to the corresponding SdkDec (`cosmossdk.io/math.LegacyDec`).
229    pub fn to_sdk_dec(&self) -> Result<SdkDec, MathError> {
230        SdkDec::new(self)
231    }
232
233    /// to_sdk_dec_pb_repr: Encodes the `DecimalExt` its SdkDec
234    /// (`cosmossdk.io/math.LegacyDec`) protobuf representation.
235    pub fn to_sdk_dec_pb_repr(&self) -> Result<String, MathError> {
236        if self.dec.is_zero() {
237            return Ok("0".repeat(DecimalExt::precision_digits()));
238        }
239
240        // Convert Decimal to string
241        let abs_str = self.dec.to_string();
242
243        // Handle negative sign
244        let neg = self.sign == Sign::Negative;
245
246        // Split into integer and fractional parts
247        let parts: Vec<&str> = abs_str.split('.').collect();
248        let (int_part, frac_part) = match parts.as_slice() {
249            [int_part, frac_part] => (*int_part, *frac_part),
250            [int_part] => (*int_part, ""),
251            _ => {
252                return Err(MathError::SdkDecError(format!(
253                    "Invalid decimal format: {}",
254                    abs_str
255                )))
256            }
257        };
258
259        // Check for valid number format
260        if int_part.is_empty() || (parts.len() == 2 && frac_part.is_empty()) {
261            return Err(MathError::SdkDecError(format!(
262                "Expected decimal string but got: {}",
263                abs_str
264            )));
265        }
266
267        // ----- Build the `sdk_dec` now that validation is complete. -----
268        // Concatenate integer and fractional parts
269        let mut sdk_dec = format!("{int_part}{frac_part}");
270
271        // Add trailing zeros to match precision
272        let precision_digits = DecimalExt::precision_digits();
273        if frac_part.len() > precision_digits {
274            return Err(MathError::SdkDecError(format!(
275                "Value exceeds max precision digits ({}): {}",
276                precision_digits, abs_str
277            )));
278        }
279        for _ in 0..(precision_digits - frac_part.len()) {
280            sdk_dec.push('0');
281        }
282
283        // Add negative sign if necessary
284        if neg {
285            sdk_dec.insert(0, '-');
286        }
287
288        Ok(sdk_dec)
289    }
290
291    pub fn from_sdk_dec(sdk_dec_str: &str) -> Result<DecimalExt, MathError> {
292        let precision_digits = DecimalExt::precision_digits();
293        if sdk_dec_str.is_empty() {
294            return Ok(DecimalExt::zero());
295        }
296
297        if sdk_dec_str.contains('.') {
298            return Err(MathError::SdkDecError(format!(
299                "Expected a decimal string but got '{}'",
300                sdk_dec_str
301            )));
302        }
303
304        // Check if negative and remove the '-' prefix if present
305        let (neg, abs_str) =
306            if let Some(stripped) = sdk_dec_str.strip_prefix('-') {
307                (true, stripped)
308            } else {
309                (false, sdk_dec_str)
310            };
311
312        if abs_str.is_empty() || abs_str.chars().any(|c| !c.is_ascii_digit()) {
313            return Err(MathError::SdkDecError(format!(
314                "Invalid decimal format: {}",
315                sdk_dec_str
316            )));
317        }
318
319        let input_size = abs_str.len();
320        let mut decimal_str = String::new();
321
322        if input_size <= precision_digits {
323            // Case 1: Purely decimal number
324            decimal_str.push_str("0.");
325            decimal_str.push_str(&"0".repeat(precision_digits - input_size));
326            decimal_str.push_str(abs_str);
327        } else {
328            // Case 2: Number has both integer and decimal parts
329            let dec_point_place = input_size - precision_digits;
330            decimal_str.push_str(&abs_str[..dec_point_place]);
331            decimal_str.push('.');
332            decimal_str.push_str(&abs_str[dec_point_place..]);
333        }
334
335        if neg {
336            decimal_str.insert(0, '-');
337        }
338
339        DecimalExt::from_str(&decimal_str).map_err(Into::into)
340    }
341}
342
#[cfg(test)]
mod test_sign_dec {
    use cosmwasm_std as cw;
    use std::str::FromStr;

    use crate::{
        errors::TestResult,
        math::{DecimalExt, Sign},
    };

    /// Returns `digits` followed by `zeros` zero characters.
    fn zero_padded(digits: &str, zeros: usize) -> String {
        format!("{}{}", digits, "0".repeat(zeros))
    }

    /// Parses `arg`, checks that it encodes to `want_sdk_dec`, and checks
    /// that decoding the encoded string gives back the parsed value.
    fn check_sdk_dec_round_trip(arg: &str, want_sdk_dec: &str) -> TestResult {
        let want_dec: DecimalExt = DecimalExt::from_str(arg)?;
        let got_sdk_dec: String = want_dec.to_sdk_dec_pb_repr()?;
        assert_eq!(want_sdk_dec, got_sdk_dec);

        let got_dec = DecimalExt::from_sdk_dec(&got_sdk_dec)?;
        assert_eq!(want_dec, got_dec);
        Ok(())
    }

    /// `default()`, `zero()` and a converted zero `cw::Decimal` all agree.
    #[test]
    fn default_is_zero() -> TestResult {
        let cw_zero = cw::Decimal::from_str("0")?;
        let explicit_zero = DecimalExt {
            sign: Sign::Zero,
            dec: cw_zero,
        };
        assert_eq!(DecimalExt::default(), explicit_zero);
        assert_eq!(DecimalExt::default(), DecimalExt::zero());
        assert_eq!(DecimalExt::zero(), DecimalExt::from(cw_zero));
        Ok(())
    }

    /// Conversions from `cw::Decimal` and from signed strings.
    #[test]
    fn from_cw() -> TestResult {
        let cw_zero = cw::Decimal::from_str("0")?;
        assert_eq!(DecimalExt::default(), DecimalExt::from(cw_zero));

        let cw_dec = cw::Decimal::from_str("123.456")?;
        let want_positive = DecimalExt {
            sign: Sign::Positive,
            dec: cw_dec,
        };
        assert_eq!(want_positive, DecimalExt::from(cw_dec));

        let num = "123.456";
        let want_negative = DecimalExt {
            sign: Sign::Negative,
            dec: cw::Decimal::from_str(num)?,
        };
        assert_eq!(want_negative, DecimalExt::from_str(&format!("-{}", num))?);

        Ok(())
    }

    // TODO: How will you handle overflow?
    #[test]
    fn add() -> TestResult {
        // (lhs, rhs, expected sum)
        let cases = [
            ("0", "0", "0"),
            ("0", "420", "420"),
            ("69", "420", "489"),
            ("5", "-3", "2"),
            ("-7", "7", "0"),
            ("-420", "69", "-351"),
            ("-69", "420", "351"),
        ];
        for (lhs, rhs, want) in cases {
            let lhs = DecimalExt::from_str(lhs)?;
            let rhs = DecimalExt::from_str(rhs)?;
            let want = DecimalExt::from_str(want)?;
            assert_eq!(want, lhs.add(rhs));
        }
        Ok(())
    }

    #[test]
    fn neg() -> TestResult {
        let positive = DecimalExt::from_str("69")?;
        let negative = DecimalExt::from_str("-69")?;
        let zero = DecimalExt::zero();

        assert_eq!(negative, positive.neg());
        assert_eq!(positive, negative.neg());
        assert_eq!(zero, zero.neg());
        Ok(())
    }

    #[test]
    fn mul() -> TestResult {
        // (lhs, rhs, expected product)
        let cases = [
            ("0", "0", "0"),
            ("0", "420", "0"),
            ("16", "16", "256"),
            ("5", "-3", "-15"),
            ("-7", "7", "-49"),
        ];
        for (lhs, rhs, want) in cases {
            let lhs = DecimalExt::from_str(lhs)?;
            let rhs = DecimalExt::from_str(rhs)?;
            let want = DecimalExt::from_str(want)?;
            assert_eq!(want, lhs.mul(rhs));
        }
        Ok(())
    }

    #[test]
    fn quo() -> TestResult {
        // (dividend, divisor, expected quotient)
        let cases = [
            ("0", "420", "0"),
            ("256", "16", "16"),
            ("-15", "5", "-3"),
            ("-49", "-7", "7"),
        ];
        for (dividend, divisor, want) in cases {
            let dividend = DecimalExt::from_str(dividend)?;
            let divisor = DecimalExt::from_str(divisor)?;
            let want = DecimalExt::from_str(want)?;
            assert_eq!(want, dividend.quo(divisor)?);
        }
        Ok(())
    }

    /// SdkDec encoding round trip for values without a fractional part.
    #[test]
    fn sdk_dec_int_only() -> TestResult {
        let cases: Vec<(&str, String)> = vec![
            // Zero cases should all be equal
            ("0", "0".repeat(18)),
            ("000.00", "0".repeat(18)),
            ("0.00", "0".repeat(18)),
            ("00000", "0".repeat(18)),
            // Non-zero cases
            ("10", zero_padded("10", 18)),
            ("-10", zero_padded("-10", 18)),
            ("123", zero_padded("123", 18)),
            ("-123", zero_padded("-123", 18)),
        ];

        for (arg, want_sdk_dec) in cases {
            check_sdk_dec_round_trip(arg, &want_sdk_dec)?;
        }
        Ok(())
    }

    /// to_sdk_dec test with fractional parts
    #[test]
    fn sdk_dec_fractional() -> TestResult {
        let cases: Vec<(&str, String)> = vec![
            ("0.5", zero_padded("05", 17)),
            ("0.005", zero_padded("0005", 15)),
            ("123.456", zero_padded("123456", 15)),
            ("-123.456", zero_padded("-123456", 15)),
            ("0.00596", zero_padded("000596", 13)),
            ("13.5", zero_padded("135", 17)),
            ("-13.5", zero_padded("-135", 17)),
            ("1574.00005", zero_padded("157400005", 13)),
        ];

        for (arg, want_sdk_dec) in cases {
            check_sdk_dec_round_trip(arg, &want_sdk_dec)?;
        }
        Ok(())
    }
}