1use crate::{ArithmeticError, Decimal, RoundingMode};
13
/// Number of fractional decimal digits carried by a fixed-point integer value.
///
/// The named variants cover the precisions 6, 8 and 18; any other digit count
/// is carried by [`OracleDecimals::Custom`].
///
/// Equality is derived, so it is structural: `Custom(6)` and `Six` compare as
/// *different* values even though [`OracleDecimals::value`] returns 6 for
/// both. Build values through `From<u8>` to obtain the canonical variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OracleDecimals {
    /// Six fractional digits.
    Six,
    /// Eight fractional digits.
    Eight,
    /// Eighteen fractional digits.
    Eighteen,
    /// Any other number of fractional digits.
    Custom(u8),
}
26
27impl OracleDecimals {
28 pub const fn value(self) -> u8 {
30 match self {
31 Self::Six => 6,
32 Self::Eight => 8,
33 Self::Eighteen => 18,
34 Self::Custom(n) => n,
35 }
36 }
37
38 pub fn scale_factor(self) -> Decimal {
40 let decimals = self.value();
41 Decimal::from(10i64).powi(decimals as i32).unwrap_or(Decimal::MAX)
42 }
43}
44
45impl From<u8> for OracleDecimals {
46 fn from(n: u8) -> Self {
47 match n {
48 6 => Self::Six,
49 8 => Self::Eight,
50 18 => Self::Eighteen,
51 _ => Self::Custom(n),
52 }
53 }
54}
55
56pub fn normalize_oracle_price(
72 raw_value: i64,
73 decimals: OracleDecimals,
74) -> Result<Decimal, ArithmeticError> {
75 let scale = decimals.scale_factor();
76 Decimal::from(raw_value)
77 .checked_div(scale)
78 .ok_or(ArithmeticError::DivisionByZero)
79}
80
81pub fn normalize_oracle_price_i128(
86 raw_value: i128,
87 decimals: OracleDecimals,
88) -> Result<Decimal, ArithmeticError> {
89 let scale = decimals.scale_factor();
90 Decimal::try_from_i128(raw_value)?
91 .checked_div(scale)
92 .ok_or(ArithmeticError::DivisionByZero)
93}
94
95pub fn denormalize_oracle_price(
112 value: Decimal,
113 decimals: OracleDecimals,
114) -> Result<i64, ArithmeticError> {
115 let scale = decimals.scale_factor();
116 let scaled = value
117 .checked_mul(scale)
118 .ok_or(ArithmeticError::Overflow)?
119 .round(0, RoundingMode::TowardZero);
120 let (mantissa, _) = scaled.to_parts();
121 i64::try_from(mantissa).map_err(|_| ArithmeticError::Overflow)
122}
123
124pub fn denormalize_oracle_price_i128(
129 value: Decimal,
130 decimals: OracleDecimals,
131) -> Result<i128, ArithmeticError> {
132 let scale = decimals.scale_factor();
133 let scaled = value
134 .checked_mul(scale)
135 .ok_or(ArithmeticError::Overflow)?
136 .round(0, RoundingMode::TowardZero);
137 let (mantissa, _) = scaled.to_parts();
138 Ok(mantissa)
139}
140
141pub fn convert_decimals(
158 value: i64,
159 from: OracleDecimals,
160 to: OracleDecimals,
161) -> Result<i64, ArithmeticError> {
162 let from_decimals = from.value() as i32;
163 let to_decimals = to.value() as i32;
164 let diff = to_decimals - from_decimals;
165
166 if diff == 0 {
167 return Ok(value);
168 }
169
170 let factor = 10i64
171 .checked_pow(diff.unsigned_abs())
172 .ok_or(ArithmeticError::Overflow)?;
173
174 if diff > 0 {
175 value.checked_mul(factor).ok_or(ArithmeticError::Overflow)
176 } else {
177 Ok(value / factor)
178 }
179}
180
181pub fn convert_decimals_i128(
200 value: i64,
201 from: OracleDecimals,
202 to: OracleDecimals,
203) -> Result<i128, ArithmeticError> {
204 let from_decimals = from.value() as i32;
205 let to_decimals = to.value() as i32;
206 let diff = to_decimals - from_decimals;
207
208 if diff == 0 {
209 return Ok(value as i128);
210 }
211
212 let factor = 10i128
213 .checked_pow(diff.unsigned_abs())
214 .ok_or(ArithmeticError::Overflow)?;
215
216 if diff > 0 {
217 (value as i128)
218 .checked_mul(factor)
219 .ok_or(ArithmeticError::Overflow)
220 } else {
221 Ok((value as i128) / factor)
222 }
223}
224
225pub fn scale_token_amount(
245 amount: i64,
246 from_decimals: OracleDecimals,
247 to_decimals: OracleDecimals,
248) -> Result<i64, ArithmeticError> {
249 convert_decimals(amount, from_decimals, to_decimals)
250}
251
252pub fn scale_token_amount_i128(
269 amount: i64,
270 from_decimals: OracleDecimals,
271 to_decimals: OracleDecimals,
272) -> Result<i128, ArithmeticError> {
273 convert_decimals_i128(amount, from_decimals, to_decimals)
274}
275
276pub fn calculate_value(
309 amount: i64,
310 amount_decimals: OracleDecimals,
311 price: i64,
312 price_decimals: OracleDecimals,
313 result_decimals: OracleDecimals,
314) -> Result<i64, ArithmeticError> {
315 let amount_dec = normalize_oracle_price(amount, amount_decimals)?;
316 let price_dec = normalize_oracle_price(price, price_decimals)?;
317
318 let value = amount_dec
319 .checked_mul(price_dec)
320 .ok_or(ArithmeticError::Overflow)?;
321
322 denormalize_oracle_price(value, result_decimals)
323}
324
325pub fn calculate_value_i128(
327 amount: i64,
328 amount_decimals: OracleDecimals,
329 price: i64,
330 price_decimals: OracleDecimals,
331 result_decimals: OracleDecimals,
332) -> Result<i128, ArithmeticError> {
333 let amount_dec = normalize_oracle_price(amount, amount_decimals)?;
334 let price_dec = normalize_oracle_price(price, price_decimals)?;
335
336 let value = amount_dec
337 .checked_mul(price_dec)
338 .ok_or(ArithmeticError::Overflow)?;
339
340 denormalize_oracle_price_i128(value, result_decimals)
341}
342
343pub fn normalize_pyth_price(price: i64, exponent: i32) -> Result<Decimal, ArithmeticError> {
360 let price_dec = Decimal::from(price);
361
362 if exponent == 0 {
363 return Ok(price_dec);
364 }
365
366 let scale = Decimal::from(10i64)
367 .powi(exponent.abs())
368 .ok_or(ArithmeticError::Overflow)?;
369
370 if exponent > 0 {
371 price_dec.checked_mul(scale).ok_or(ArithmeticError::Overflow)
372 } else {
373 price_dec.checked_div(scale).ok_or(ArithmeticError::DivisionByZero)
374 }
375}
376
/// Unit tests for decimal normalisation, rescaling and value calculation.
#[cfg(test)]
mod tests {
    extern crate alloc;

    use super::*;
    use alloc::string::ToString;
    use core::str::FromStr;

    // An 8-decimal raw value renders with eight fractional digits.
    #[test]
    fn test_normalize_chainlink_price() {
        let raw = 250_012_345_678i64;
        let price = normalize_oracle_price(raw, OracleDecimals::Eight).unwrap();
        assert_eq!(price.to_string(), "2500.12345678");
    }

    #[test]
    fn test_denormalize_price() {
        let price = Decimal::from_str("2500.12345678").unwrap();
        let raw = denormalize_oracle_price(price, OracleDecimals::Eight).unwrap();
        assert_eq!(raw, 250_012_345_678);
    }

    // Dropping two digits of precision truncates the trailing "78".
    #[test]
    fn test_convert_8_to_6_decimals() {
        let chainlink = 250_012_345_678i64;
        let usdc =
            convert_decimals(chainlink, OracleDecimals::Eight, OracleDecimals::Six).unwrap();
        assert_eq!(usdc, 2_500_123_456);
    }

    #[test]
    fn test_convert_8_to_18_decimals_i128() {
        let chainlink = 250_012_345_678i64;
        let onchain =
            convert_decimals_i128(chainlink, OracleDecimals::Eight, OracleDecimals::Eighteen)
                .unwrap();
        assert_eq!(onchain, 2_500_123_456_780_000_000_000i128);
    }

    // Normalising and then denormalising at the same precision is lossless.
    #[test]
    fn test_normalize_denormalize_round_trip() {
        let original = 250_012_345_678i64;
        let normalized = normalize_oracle_price(original, OracleDecimals::Eight).unwrap();
        let recovered = denormalize_oracle_price(normalized, OracleDecimals::Eight).unwrap();
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_scale_usdc_to_8_decimals() {
        // 1000 units at 6 decimals.
        let usdc = 1_000_000_000i64;
        let scaled =
            scale_token_amount(usdc, OracleDecimals::Six, OracleDecimals::Eight).unwrap();
        assert_eq!(scaled, 100_000_000_000);
    }

    #[test]
    fn test_scale_usdc_to_18_decimals_i128() {
        // 1000 units at 6 decimals.
        let usdc = 1_000_000_000i64;
        let scaled =
            scale_token_amount_i128(usdc, OracleDecimals::Six, OracleDecimals::Eighteen).unwrap();
        assert_eq!(scaled, 1_000_000_000_000_000_000_000i128);
    }

    #[test]
    fn test_pyth_positive_exponent() {
        let price = normalize_pyth_price(25, 2).unwrap();
        assert_eq!(price.to_string(), "2500");
    }

    #[test]
    fn test_pyth_negative_exponent() {
        let price = normalize_pyth_price(250_012_345_678, -8).unwrap();
        assert_eq!(price.to_string(), "2500.12345678");
    }

    #[test]
    fn test_pyth_zero_exponent() {
        let price = normalize_pyth_price(2500, 0).unwrap();
        assert_eq!(price.to_string(), "2500");
    }

    #[test]
    fn test_calculate_usdc_value() {
        // 1000 units at 6 decimals, priced at 1.00 with 8 decimals.
        let usdc_amount = 1_000_000_000i64;
        let usdc_price = 100_000_000i64;
        let value = calculate_value(
            usdc_amount,
            OracleDecimals::Six,
            usdc_price,
            OracleDecimals::Eight,
            OracleDecimals::Six,
        )
        .unwrap();

        // 1000.000000 at 6 decimals.
        assert_eq!(value, 1_000_000_000);
    }

    #[test]
    fn test_calculate_btc_value() {
        // 0.1 units at 8 decimals, priced at 50000.00 with 8 decimals.
        let btc_amount = 10_000_000i64;
        let btc_price = 5_000_000_000_000i64;
        let value = calculate_value(
            btc_amount,
            OracleDecimals::Eight,
            btc_price,
            OracleDecimals::Eight,
            OracleDecimals::Six,
        )
        .unwrap();

        // 5000.000000 at 6 decimals.
        assert_eq!(value, 5_000_000_000);
    }

    #[test]
    fn test_oracle_decimals_from_u8() {
        assert_eq!(OracleDecimals::from(6), OracleDecimals::Six);
        assert_eq!(OracleDecimals::from(8), OracleDecimals::Eight);
        assert_eq!(OracleDecimals::from(18), OracleDecimals::Eighteen);
        assert_eq!(OracleDecimals::from(12), OracleDecimals::Custom(12));
    }

    #[test]
    fn test_oracle_decimals_value() {
        assert_eq!(OracleDecimals::Six.value(), 6);
        assert_eq!(OracleDecimals::Eight.value(), 8);
        assert_eq!(OracleDecimals::Eighteen.value(), 18);
        assert_eq!(OracleDecimals::Custom(12).value(), 12);
    }

    // Scaling down discards digits toward zero, also for negative values.
    #[test]
    fn test_convert_truncates_toward_zero() {
        let scaled =
            convert_decimals(-199, OracleDecimals::Custom(2), OracleDecimals::Custom(0)).unwrap();
        assert_eq!(scaled, -1);
    }

    #[test]
    fn test_convert_upscale_overflow() {
        let result = convert_decimals(i64::MAX, OracleDecimals::Six, OracleDecimals::Eight);
        assert!(matches!(result, Err(ArithmeticError::Overflow)));
    }
}