Skip to main content

krusty_kms_client/
types.rs

1//! Types for interacting with TONGO contracts on Starknet.
2
3use krusty_kms_common::Result;
4use starknet_types_core::curve::ProjectivePoint;
5use starknet_types_core::felt::Felt;
6
7/// Cipher balance stored on-chain (ElGamal ciphertext).
8#[derive(Debug, Clone)]
9pub struct CipherBalance {
10    pub l: ProjectivePoint,
11    pub r: ProjectivePoint,
12}
13
14/// Account state from the TONGO contract.
15#[derive(Debug, Clone)]
16pub struct AccountState {
17    /// Current spendable balance (encrypted)
18    pub balance: CipherBalance,
19    /// Pending balance from transfers (encrypted)
20    pub pending: CipherBalance,
21    /// Account nonce
22    pub nonce: Felt,
23}
24
25/// Decrypted account state.
26#[derive(Debug, Clone)]
27pub struct DecryptedAccountState {
28    /// Current spendable balance (plaintext)
29    pub balance: u128,
30    /// Pending balance from transfers (plaintext)
31    pub pending: u128,
32    /// Account nonce
33    pub nonce: Felt,
34}
35
36/// Decrypt a cipher balance using ElGamal decryption.
37///
38/// Given C = (L, R) = (g^m * y^r, g^r), where:
39/// - g is the generator
40/// - m is the message (balance)
41/// - y is the public key
42/// - r is the random nonce
43///
44/// We can decrypt by computing: m = L / R^x, where x is the private key.
45///
46/// # Cyclomatic Complexity: 3
47pub fn decrypt_cipher_balance(private_key: &Felt, cipher: &CipherBalance) -> Result<u128> {
48    // Calculate R^x (scalar multiplication)
49    let r_x = multiply_point(&cipher.r, private_key)?;
50
51    // Calculate L - R^x to get g^m
52    let g_m = subtract_points(&cipher.l, &r_x)?;
53
54    // Perform discrete log to recover m
55    // For small values (typical balances), we use brute force
56    let balance = discrete_log_brute_force(&g_m)?;
57
58    Ok(balance)
59}
60
61/// Multiply a point by a scalar.
62fn multiply_point(point: &ProjectivePoint, scalar: &Felt) -> Result<ProjectivePoint> {
63    // Perform scalar multiplication using double-and-add
64    let scalar_bytes = scalar.to_bytes_be();
65    let mut result: Option<ProjectivePoint> = None;
66    let mut temp = point.clone();
67
68    for byte in scalar_bytes.iter().rev() {
69        for i in 0..8 {
70            if (byte >> i) & 1 == 1 {
71                result = Some(match result {
72                    Some(r) => &r + &temp,
73                    None => temp.clone(),
74                });
75            }
76            temp = &temp + &temp;
77        }
78    }
79
80    Ok(result.unwrap_or(point.clone()))
81}
82
83/// Subtract two points (add the inverse).
84fn subtract_points(a: &ProjectivePoint, b: &ProjectivePoint) -> Result<ProjectivePoint> {
85    // Negate b by negating the y-coordinate
86    let b_affine = b.to_affine().map_err(|_| {
87        krusty_kms_common::KmsError::CryptoError("Invalid point (identity)".to_string())
88    })?;
89
90    let neg_y = Felt::ZERO - b_affine.y();
91    let neg_b = ProjectivePoint::from_affine(b_affine.x(), neg_y).map_err(|_| {
92        krusty_kms_common::KmsError::CryptoError("Invalid negated point".to_string())
93    })?;
94
95    Ok(a + &neg_b)
96}
97
98/// Recover the discrete log m from g^m using brute force.
99///
100/// This works for small values (up to ~10^12), which is sufficient for
101/// typical TONGO balances.
102///
103/// # Cyclomatic Complexity: 3
104fn discrete_log_brute_force(g_m: &ProjectivePoint) -> Result<u128> {
105    // Use the standard Stark curve generator from krusty-kms-crypto
106    let generator = krusty_kms_crypto::StarkCurve::generator();
107
108    // Try to convert to affine - if it fails, it's the identity (balance = 0)
109    if g_m.to_affine().is_err() {
110        return Ok(0);
111    }
112
113    // Brute force search up to MAX_SEARCH
114    const MAX_SEARCH: u128 = 1_000_000_000_000; // 1 trillion
115    let mut current = generator.clone();
116
117    for i in 1..=MAX_SEARCH {
118        if points_equal(&current, g_m) {
119            return Ok(i);
120        }
121        current = &current + &generator;
122
123        // Early exit if we've gone past reasonable balance values
124        if i > 1_000_000 && i % 1_000_000 == 0 {
125            // Check every million after the first million
126        }
127    }
128
129    Err(krusty_kms_common::KmsError::CryptoError(
130        "Failed to recover balance (discrete log not found within search limit)".to_string(),
131    ))
132}
133
134/// Check if two points are equal.
135fn points_equal(a: &ProjectivePoint, b: &ProjectivePoint) -> bool {
136    match (a.to_affine(), b.to_affine()) {
137        (Ok(a_aff), Ok(b_aff)) => a_aff.x() == b_aff.x() && a_aff.y() == b_aff.y(),
138        (Err(_), Err(_)) => true, // Both are identity
139        _ => false,
140    }
141}
142
/// Convert an ERC-20 amount to Tongo units, rounding up.
///
/// Computes `ceil(erc20_amount / rate)`: any non-zero remainder adds one
/// extra Tongo unit.
///
/// # Panics
///
/// Panics if `rate` is zero (division by zero).
pub fn erc20_to_tongo(erc20_amount: u128, rate: u128) -> u128 {
    let whole_units = erc20_amount / rate;
    let has_remainder = erc20_amount % rate != 0;
    whole_units + u128::from(has_remainder)
}
147
/// Convert a Tongo amount to ERC-20 units (`tongo_amount * rate`).
///
/// # Panics
///
/// Panics if the product overflows `u128`. A plain `*` panics only in debug
/// builds and silently wraps in release builds, which would yield a wrong
/// token amount. `checked_mul` makes the failure explicit in every build
/// profile.
pub fn tongo_to_erc20(tongo_amount: u128, rate: u128) -> u128 {
    tongo_amount
        .checked_mul(rate)
        .expect("tongo_to_erc20: tongo_amount * rate overflows u128")
}
152
/// Authenticated-encryption balance in its raw on-chain representation.
///
/// Only fixed-size byte arrays are stored, so equality is derived to let
/// callers compare on-chain values directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AEBalance {
    /// Raw 64-byte ciphertext.
    pub ciphertext: [u8; 64],
    /// Raw 24-byte nonce.
    pub nonce: [u8; 24],
}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Hex-encoded affine x-coordinate of the STARK curve generator.
    const GENERATOR_X_HEX: &str =
        "0x1ef15c18599971b7beced415a40f0c7deacfd9b0d1819e03d723d8bc943cfca";
    /// Hex-encoded affine y-coordinate of the STARK curve generator.
    const GENERATOR_Y_HEX: &str =
        "0x5668060aa49730b7be4801df46ec62de53ecd11abe43a32873000c36e8dc1f";

    /// Build the generator point `g` from its hard-coded coordinates.
    fn generator() -> ProjectivePoint {
        ProjectivePoint::from_affine(
            Felt::from_hex(GENERATOR_X_HEX).unwrap(),
            Felt::from_hex(GENERATOR_Y_HEX).unwrap(),
        )
        .unwrap()
    }

    /// Multiply `p` by a small integer scalar, unwrapping the result.
    fn scalar_mul(p: &ProjectivePoint, k: u64) -> ProjectivePoint {
        multiply_point(p, &Felt::from(k)).unwrap()
    }

    #[test]
    fn test_decrypt_zero_balance() {
        let secret = Felt::from(12345u64);
        let g = generator();
        let public_key = multiply_point(&g, &secret).unwrap();

        // Encrypting 0 leaves only the mask: C = (y^r, g^r).
        let nonce = 999;
        let cipher = CipherBalance {
            l: scalar_mul(&public_key, nonce),
            r: scalar_mul(&g, nonce),
        };

        assert_eq!(decrypt_cipher_balance(&secret, &cipher).unwrap(), 0);
    }

    #[test]
    fn test_point_subtraction() {
        let g = generator();
        // g - g must be the identity, which has no affine form.
        assert!(subtract_points(&g, &g).unwrap().to_affine().is_err());
    }

    #[test]
    fn test_points_equal_same_point() {
        let g = generator();
        assert!(points_equal(&g, &g));
    }

    #[test]
    fn test_points_equal_different_points() {
        let g = generator();
        let doubled = &g + &g;
        assert!(!points_equal(&g, &doubled));
    }

    #[test]
    fn test_points_equal_both_identity() {
        let g = generator();
        let first = subtract_points(&g, &g).unwrap();
        let second = subtract_points(&g, &g).unwrap();
        assert!(points_equal(&first, &second));
    }

    #[test]
    fn test_points_equal_one_identity() {
        let g = generator();
        let identity = subtract_points(&g, &g).unwrap();
        assert!(!points_equal(&g, &identity));
        assert!(!points_equal(&identity, &g));
    }

    #[test]
    fn test_multiply_point_by_one() {
        let g = generator();
        let product = multiply_point(&g, &Felt::ONE).unwrap();
        assert!(points_equal(&product, &g));
    }

    #[test]
    fn test_multiply_point_by_two() {
        let g = generator();
        let product = multiply_point(&g, &Felt::TWO).unwrap();
        let doubled = &g + &g;
        assert!(points_equal(&product, &doubled));
    }

    #[test]
    fn test_discrete_log_small_value() {
        // log_g(g) == 1
        assert_eq!(discrete_log_brute_force(&generator()).unwrap(), 1);
    }

    #[test]
    fn test_discrete_log_value_5() {
        let five_g = scalar_mul(&generator(), 5);
        assert_eq!(discrete_log_brute_force(&five_g).unwrap(), 5);
    }

    #[test]
    fn test_decrypt_small_balance() {
        let secret = Felt::from(12345u64);
        let g = generator();
        let public_key = multiply_point(&g, &secret).unwrap();

        // Encrypt 5: C = (g^5 * y^r, g^r).
        let nonce = 999;
        let l = &scalar_mul(&g, 5) + &scalar_mul(&public_key, nonce);
        let cipher = CipherBalance {
            l,
            r: scalar_mul(&g, nonce),
        };

        assert_eq!(decrypt_cipher_balance(&secret, &cipher).unwrap(), 5);
    }

    #[test]
    fn test_erc20_to_tongo_exact() {
        assert_eq!(erc20_to_tongo(1000, 10), 100);
    }

    #[test]
    fn test_erc20_to_tongo_ceiling() {
        assert_eq!(erc20_to_tongo(1001, 10), 101);
        assert_eq!(erc20_to_tongo(1009, 10), 101);
    }

    #[test]
    fn test_erc20_to_tongo_rate_one() {
        assert_eq!(erc20_to_tongo(42, 1), 42);
    }

    #[test]
    fn test_erc20_to_tongo_rate_greater_than_amount() {
        assert_eq!(erc20_to_tongo(5, 100), 1);
    }

    #[test]
    fn test_erc20_to_tongo_zero_amount() {
        assert_eq!(erc20_to_tongo(0, 10), 0);
    }

    #[test]
    fn test_tongo_to_erc20_basic() {
        assert_eq!(tongo_to_erc20(100, 10), 1000);
    }

    #[test]
    fn test_tongo_to_erc20_rate_one() {
        assert_eq!(tongo_to_erc20(42, 1), 42);
    }

    #[test]
    fn test_tongo_to_erc20_zero_amount() {
        assert_eq!(tongo_to_erc20(0, 10), 0);
    }

    #[test]
    fn test_roundtrip_conversion() {
        let rate = 1000u128;
        let original_tongo = 50u128;
        let back = erc20_to_tongo(tongo_to_erc20(original_tongo, rate), rate);
        assert_eq!(back, original_tongo);
    }
}