Skip to main content

krusty_kms_common/
amount.rs

//! Precision-safe token amount representation.
2
3use serde::{Deserialize, Serialize};
4use starknet_types_core::felt::Felt;
5use std::fmt;
6
7use crate::{KmsError, Result};
8
9/// A token amount with associated decimal precision.
10///
11/// Stores the raw (smallest-unit) value plus the number of decimals,
12/// avoiding floating-point imprecision.
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
14pub struct Amount {
15    raw: u128,
16    decimals: u8,
17}
18
19impl Amount {
20    /// Create from the raw (smallest-unit) integer value.
21    pub fn from_raw(raw: u128, decimals: u8) -> Self {
22        Self { raw, decimals }
23    }
24
25    /// Parse a human-readable decimal string (e.g. `"1.5"`) into an `Amount`.
26    pub fn from_human(s: &str, decimals: u8) -> Result<Self> {
27        let s = s.trim();
28        let factor = 10u128
29            .checked_pow(decimals as u32)
30            .ok_or_else(|| KmsError::InvalidAmount("decimals too large".into()))?;
31
32        let raw = if let Some(dot) = s.find('.') {
33            let integer_part = &s[..dot];
34            let frac_part = &s[dot + 1..];
35
36            if frac_part.len() > decimals as usize {
37                return Err(KmsError::InvalidAmount(format!(
38                    "too many decimal places (max {})",
39                    decimals
40                )));
41            }
42
43            let int_val: u128 = if integer_part.is_empty() {
44                0
45            } else {
46                integer_part
47                    .parse()
48                    .map_err(|_| KmsError::InvalidAmount(format!("invalid number: {}", s)))?
49            };
50
51            let padded_frac = format!("{:0<width$}", frac_part, width = decimals as usize);
52            let frac_val: u128 = padded_frac
53                .parse()
54                .map_err(|_| KmsError::InvalidAmount(format!("invalid fraction: {}", s)))?;
55
56            int_val
57                .checked_mul(factor)
58                .and_then(|v| v.checked_add(frac_val))
59                .ok_or_else(|| KmsError::InvalidAmount("overflow".into()))?
60        } else {
61            let int_val: u128 = s
62                .parse()
63                .map_err(|_| KmsError::InvalidAmount(format!("invalid number: {}", s)))?;
64            int_val
65                .checked_mul(factor)
66                .ok_or_else(|| KmsError::InvalidAmount("overflow".into()))?
67        };
68
69        Ok(Self { raw, decimals })
70    }
71
72    /// The raw (smallest-unit) value.
73    pub fn raw(&self) -> u128 {
74        self.raw
75    }
76
77    /// Number of decimals.
78    pub fn decimals(&self) -> u8 {
79        self.decimals
80    }
81
82    /// Convert to a human-readable decimal string.
83    pub fn to_human(&self) -> String {
84        if self.decimals == 0 {
85            return self.raw.to_string();
86        }
87        let factor = 10u128.pow(self.decimals as u32);
88        let integer = self.raw / factor;
89        let fraction = self.raw % factor;
90        if fraction == 0 {
91            format!("{}.0", integer)
92        } else {
93            let frac_str = format!("{:0>width$}", fraction, width = self.decimals as usize);
94            let trimmed = frac_str.trim_end_matches('0');
95            format!("{}.{}", integer, trimmed)
96        }
97    }
98
99    /// Encode as a Starknet u256 `(low, high)` Felt pair.
100    ///
101    /// Since the raw value is `u128`, it fits entirely in the low limb.
102    pub fn to_u256(&self) -> (Felt, Felt) {
103        (Felt::from(self.raw), Felt::ZERO)
104    }
105
106    /// Checked addition (decimals must match).
107    pub fn checked_add(&self, other: &Amount) -> Result<Amount> {
108        if self.decimals != other.decimals {
109            return Err(KmsError::InvalidAmount(
110                "cannot add amounts with different decimals".into(),
111            ));
112        }
113        let raw = self
114            .raw
115            .checked_add(other.raw)
116            .ok_or_else(|| KmsError::InvalidAmount("overflow".into()))?;
117        Ok(Amount {
118            raw,
119            decimals: self.decimals,
120        })
121    }
122
123    /// Checked subtraction (decimals must match).
124    pub fn checked_sub(&self, other: &Amount) -> Result<Amount> {
125        if self.decimals != other.decimals {
126            return Err(KmsError::InvalidAmount(
127                "cannot subtract amounts with different decimals".into(),
128            ));
129        }
130        let raw = self
131            .raw
132            .checked_sub(other.raw)
133            .ok_or_else(|| KmsError::InvalidAmount("underflow".into()))?;
134        Ok(Amount {
135            raw,
136            decimals: self.decimals,
137        })
138    }
139
140    /// Returns true if the amount is zero.
141    pub fn is_zero(&self) -> bool {
142        self.raw == 0
143    }
144}
145
146impl fmt::Display for Amount {
147    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
148        f.write_str(&self.to_human())
149    }
150}
151
#[cfg(test)]
mod tests {
    // Unit tests for `Amount`: parsing, formatting, u256 encoding and
    // checked arithmetic, including boundary and error paths.
    use super::*;

    #[test]
    fn test_from_raw() {
        let amt = Amount::from_raw(1_500_000_000_000_000_000, 18);
        assert_eq!(amt.to_human(), "1.5");
    }

    #[test]
    fn test_from_human_integer() {
        let amt = Amount::from_human("100", 18).unwrap();
        assert_eq!(amt.raw(), 100_000_000_000_000_000_000);
    }

    #[test]
    fn test_from_human_decimal() {
        let amt = Amount::from_human("1.5", 18).unwrap();
        assert_eq!(amt.raw(), 1_500_000_000_000_000_000);
    }

    #[test]
    fn test_from_human_leading_dot() {
        let amt = Amount::from_human(".5", 18).unwrap();
        assert_eq!(amt.raw(), 500_000_000_000_000_000);
    }

    #[test]
    fn test_from_human_trailing_dot() {
        let amt = Amount::from_human("1.", 6).unwrap();
        assert_eq!(amt.raw(), 1_000_000);
    }

    #[test]
    fn test_from_human_trims_whitespace() {
        let amt = Amount::from_human("  2.5 ", 1).unwrap();
        assert_eq!(amt.raw(), 25);
    }

    #[test]
    fn test_from_human_too_many_decimals() {
        assert!(Amount::from_human("1.1234567", 6).is_err());
    }

    #[test]
    fn test_from_human_rejects_garbage() {
        assert!(Amount::from_human("", 6).is_err());
        assert!(Amount::from_human("abc", 6).is_err());
        assert!(Amount::from_human("-1", 6).is_err());
        assert!(Amount::from_human("1.2.3", 6).is_err());
    }

    #[test]
    fn test_from_human_decimals_limit() {
        // 10^38 is the largest power of ten that fits in a u128.
        assert_eq!(Amount::from_human("1", 38).unwrap().raw(), 10u128.pow(38));
        assert!(Amount::from_human("1", 39).is_err());
    }

    #[test]
    fn test_from_human_overflow() {
        let max = u128::MAX.to_string();
        assert_eq!(Amount::from_human(&max, 0).unwrap().raw(), u128::MAX);
        assert!(Amount::from_human(&max, 1).is_err());
    }

    #[test]
    fn test_to_human_whole() {
        let amt = Amount::from_raw(2_000_000, 6);
        assert_eq!(amt.to_human(), "2.0");
    }

    #[test]
    fn test_to_human_leading_fraction_zeros() {
        assert_eq!(Amount::from_raw(1_050_000, 6).to_human(), "1.05");
        assert_eq!(Amount::from_raw(1, 6).to_human(), "0.000001");
    }

    #[test]
    fn test_human_round_trip() {
        for s in ["0.000001", "1.05", "123.456789", "2.0"] {
            let amt = Amount::from_human(s, 6).unwrap();
            assert_eq!(amt.to_human(), s);
        }
    }

    #[test]
    fn test_to_u256() {
        let amt = Amount::from_raw(1000, 6);
        let (low, high) = amt.to_u256();
        assert_eq!(low, Felt::from(1000u64));
        assert_eq!(high, Felt::ZERO);
    }

    #[test]
    fn test_checked_add() {
        let a = Amount::from_raw(100, 6);
        let b = Amount::from_raw(200, 6);
        let c = a.checked_add(&b).unwrap();
        assert_eq!(c.raw(), 300);
        assert_eq!(c.decimals(), 6);
    }

    #[test]
    fn test_checked_add_different_decimals() {
        let a = Amount::from_raw(100, 6);
        let b = Amount::from_raw(200, 18);
        assert!(a.checked_add(&b).is_err());
    }

    #[test]
    fn test_checked_add_overflow() {
        let a = Amount::from_raw(u128::MAX, 0);
        assert!(a.checked_add(&Amount::from_raw(1, 0)).is_err());
    }

    #[test]
    fn test_checked_sub() {
        let a = Amount::from_raw(300, 6);
        let b = Amount::from_raw(100, 6);
        let c = a.checked_sub(&b).unwrap();
        assert_eq!(c.raw(), 200);
        assert_eq!(c.decimals(), 6);
    }

    #[test]
    fn test_checked_sub_underflow() {
        let a = Amount::from_raw(100, 6);
        let b = Amount::from_raw(200, 6);
        assert!(a.checked_sub(&b).is_err());
    }

    #[test]
    fn test_checked_sub_different_decimals() {
        let a = Amount::from_raw(300, 6);
        let b = Amount::from_raw(100, 18);
        assert!(a.checked_sub(&b).is_err());
    }

    #[test]
    fn test_is_zero() {
        assert!(Amount::from_raw(0, 18).is_zero());
        assert!(!Amount::from_raw(1, 18).is_zero());
    }

    #[test]
    fn test_display() {
        let amt = Amount::from_raw(1_500_000, 6);
        assert_eq!(format!("{}", amt), "1.5");
    }

    #[test]
    fn test_zero_decimals() {
        let amt = Amount::from_raw(42, 0);
        assert_eq!(amt.to_human(), "42");
    }
}