1use crate::{ArithmeticError, Decimal, RoundingMode};
13
/// Number of decimal places used by an oracle feed or a raw token amount.
///
/// The frequently used precisions have named variants; any other precision is
/// carried by [`OracleDecimals::Custom`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OracleDecimals {
    /// Six decimal places.
    Six,
    /// Eight decimal places.
    Eight,
    /// Eighteen decimal places.
    Eighteen,
    /// Any other number of decimal places.
    ///
    /// Equality is structural (derived), so `Custom(6)` does not compare equal
    /// to `Six`; convert through `From<u8>` to obtain the canonical variant.
    Custom(u8),
}
26
27impl OracleDecimals {
28 pub const fn value(self) -> u8 {
30 match self {
31 Self::Six => 6,
32 Self::Eight => 8,
33 Self::Eighteen => 18,
34 Self::Custom(n) => n,
35 }
36 }
37
38 pub fn scale_factor(self) -> Decimal {
40 let decimals = self.value();
41 Decimal::from(10i64)
42 .powi(decimals as i32)
43 .unwrap_or(Decimal::MAX)
44 }
45}
46
47impl From<u8> for OracleDecimals {
48 fn from(n: u8) -> Self {
49 match n {
50 6 => Self::Six,
51 8 => Self::Eight,
52 18 => Self::Eighteen,
53 _ => Self::Custom(n),
54 }
55 }
56}
57
58pub fn normalize_oracle_price(
74 raw_value: i64,
75 decimals: OracleDecimals,
76) -> Result<Decimal, ArithmeticError> {
77 let scale = decimals.scale_factor();
78 Decimal::from(raw_value)
79 .checked_div(scale)
80 .ok_or(ArithmeticError::DivisionByZero)
81}
82
83pub fn normalize_oracle_price_i128(
88 raw_value: i128,
89 decimals: OracleDecimals,
90) -> Result<Decimal, ArithmeticError> {
91 let scale = decimals.scale_factor();
92 Decimal::try_from_i128(raw_value)?
93 .checked_div(scale)
94 .ok_or(ArithmeticError::DivisionByZero)
95}
96
97pub fn denormalize_oracle_price(
114 value: Decimal,
115 decimals: OracleDecimals,
116) -> Result<i64, ArithmeticError> {
117 let scale = decimals.scale_factor();
118 let scaled = value
119 .checked_mul(scale)
120 .ok_or(ArithmeticError::Overflow)?
121 .round(0, RoundingMode::TowardZero);
122 let (mantissa, _) = scaled.to_parts();
123 i64::try_from(mantissa).map_err(|_| ArithmeticError::Overflow)
124}
125
126pub fn denormalize_oracle_price_i128(
131 value: Decimal,
132 decimals: OracleDecimals,
133) -> Result<i128, ArithmeticError> {
134 let scale = decimals.scale_factor();
135 let scaled = value
136 .checked_mul(scale)
137 .ok_or(ArithmeticError::Overflow)?
138 .round(0, RoundingMode::TowardZero);
139 let (mantissa, _) = scaled.to_parts();
140 Ok(mantissa)
141}
142
143pub fn convert_decimals(
160 value: i64,
161 from: OracleDecimals,
162 to: OracleDecimals,
163) -> Result<i64, ArithmeticError> {
164 let from_decimals = from.value() as i32;
165 let to_decimals = to.value() as i32;
166 let diff = to_decimals - from_decimals;
167
168 if diff == 0 {
169 return Ok(value);
170 }
171
172 let factor = 10i64
173 .checked_pow(diff.unsigned_abs())
174 .ok_or(ArithmeticError::Overflow)?;
175
176 if diff > 0 {
177 value.checked_mul(factor).ok_or(ArithmeticError::Overflow)
178 } else {
179 Ok(value / factor)
180 }
181}
182
183pub fn convert_decimals_i128(
202 value: i64,
203 from: OracleDecimals,
204 to: OracleDecimals,
205) -> Result<i128, ArithmeticError> {
206 let from_decimals = from.value() as i32;
207 let to_decimals = to.value() as i32;
208 let diff = to_decimals - from_decimals;
209
210 if diff == 0 {
211 return Ok(value as i128);
212 }
213
214 let factor = 10i128
215 .checked_pow(diff.unsigned_abs())
216 .ok_or(ArithmeticError::Overflow)?;
217
218 if diff > 0 {
219 (value as i128)
220 .checked_mul(factor)
221 .ok_or(ArithmeticError::Overflow)
222 } else {
223 Ok((value as i128) / factor)
224 }
225}
226
227pub fn scale_token_amount(
247 amount: i64,
248 from_decimals: OracleDecimals,
249 to_decimals: OracleDecimals,
250) -> Result<i64, ArithmeticError> {
251 convert_decimals(amount, from_decimals, to_decimals)
252}
253
254pub fn scale_token_amount_i128(
271 amount: i64,
272 from_decimals: OracleDecimals,
273 to_decimals: OracleDecimals,
274) -> Result<i128, ArithmeticError> {
275 convert_decimals_i128(amount, from_decimals, to_decimals)
276}
277
278pub fn calculate_value(
311 amount: i64,
312 amount_decimals: OracleDecimals,
313 price: i64,
314 price_decimals: OracleDecimals,
315 result_decimals: OracleDecimals,
316) -> Result<i64, ArithmeticError> {
317 let amount_dec = normalize_oracle_price(amount, amount_decimals)?;
318 let price_dec = normalize_oracle_price(price, price_decimals)?;
319
320 let value = amount_dec
321 .checked_mul(price_dec)
322 .ok_or(ArithmeticError::Overflow)?;
323
324 denormalize_oracle_price(value, result_decimals)
325}
326
327pub fn calculate_value_i128(
329 amount: i64,
330 amount_decimals: OracleDecimals,
331 price: i64,
332 price_decimals: OracleDecimals,
333 result_decimals: OracleDecimals,
334) -> Result<i128, ArithmeticError> {
335 let amount_dec = normalize_oracle_price(amount, amount_decimals)?;
336 let price_dec = normalize_oracle_price(price, price_decimals)?;
337
338 let value = amount_dec
339 .checked_mul(price_dec)
340 .ok_or(ArithmeticError::Overflow)?;
341
342 denormalize_oracle_price_i128(value, result_decimals)
343}
344
345pub fn normalize_pyth_price(price: i64, exponent: i32) -> Result<Decimal, ArithmeticError> {
362 let price_dec = Decimal::from(price);
363
364 if exponent == 0 {
365 return Ok(price_dec);
366 }
367
368 let scale = Decimal::from(10i64)
369 .powi(exponent.abs())
370 .ok_or(ArithmeticError::Overflow)?;
371
372 if exponent > 0 {
373 price_dec
374 .checked_mul(scale)
375 .ok_or(ArithmeticError::Overflow)
376 } else {
377 price_dec
378 .checked_div(scale)
379 .ok_or(ArithmeticError::DivisionByZero)
380 }
381}
382
#[cfg(test)]
mod tests {
    extern crate alloc;

    use super::*;
    use alloc::string::ToString;
    use core::str::FromStr;

    #[test]
    fn test_normalize_chainlink_price() {
        // 2500.12345678 expressed with 8 implied decimal places.
        let feed_answer: i64 = 250_012_345_678;
        let normalized = normalize_oracle_price(feed_answer, OracleDecimals::Eight).unwrap();
        assert_eq!(normalized.to_string(), "2500.12345678");
    }

    #[test]
    fn test_denormalize_price() {
        let normalized = Decimal::from_str("2500.12345678").unwrap();
        let feed_answer = denormalize_oracle_price(normalized, OracleDecimals::Eight).unwrap();
        assert_eq!(feed_answer, 250_012_345_678);
    }

    #[test]
    fn test_convert_8_to_6_decimals() {
        // Dropping two decimal places truncates the trailing "78".
        let eight_dp: i64 = 250_012_345_678;
        let six_dp =
            convert_decimals(eight_dp, OracleDecimals::Eight, OracleDecimals::Six).unwrap();
        assert_eq!(six_dp, 2_500_123_456);
    }

    #[test]
    fn test_convert_8_to_18_decimals_i128() {
        let eight_dp: i64 = 250_012_345_678;
        let eighteen_dp =
            convert_decimals_i128(eight_dp, OracleDecimals::Eight, OracleDecimals::Eighteen)
                .unwrap();
        assert_eq!(eighteen_dp, 2_500_123_456_780_000_000_000i128);
    }

    // Despite its name, this test round-trips a value at 8 decimal places
    // through normalize/denormalize and checks it comes back unchanged.
    #[test]
    fn test_convert_18_to_8_decimals_via_normalize() {
        let original: i64 = 250_012_345_678;
        let round_tripped = normalize_oracle_price(original, OracleDecimals::Eight)
            .and_then(|normalized| denormalize_oracle_price(normalized, OracleDecimals::Eight))
            .unwrap();
        assert_eq!(round_tripped, original);
    }

    #[test]
    fn test_scale_usdc_to_8_decimals() {
        // 1,000 whole units at 6 decimal places.
        let amount: i64 = 1_000_000_000;
        let rescaled =
            scale_token_amount(amount, OracleDecimals::Six, OracleDecimals::Eight).unwrap();
        assert_eq!(rescaled, 100_000_000_000);
    }

    #[test]
    fn test_scale_usdc_to_18_decimals_i128() {
        // 1,000 whole units at 6 decimal places.
        let amount: i64 = 1_000_000_000;
        let rescaled =
            scale_token_amount_i128(amount, OracleDecimals::Six, OracleDecimals::Eighteen)
                .unwrap();
        assert_eq!(rescaled, 1_000_000_000_000_000_000_000i128);
    }

    #[test]
    fn test_pyth_positive_exponent() {
        assert_eq!(normalize_pyth_price(25, 2).unwrap().to_string(), "2500");
    }

    #[test]
    fn test_pyth_negative_exponent() {
        assert_eq!(
            normalize_pyth_price(250_012_345_678, -8).unwrap().to_string(),
            "2500.12345678"
        );
    }

    #[test]
    fn test_pyth_zero_exponent() {
        assert_eq!(normalize_pyth_price(2500, 0).unwrap().to_string(), "2500");
    }

    #[test]
    fn test_calculate_usdc_value() {
        // 1,000 units (6 dp) at a price of 1.0 (8 dp), reported at 6 dp.
        let amount: i64 = 1_000_000_000;
        let unit_price: i64 = 100_000_000;
        let value = calculate_value(
            amount,
            OracleDecimals::Six,
            unit_price,
            OracleDecimals::Eight,
            OracleDecimals::Six,
        )
        .unwrap();
        assert_eq!(value, 1_000_000_000);
    }

    #[test]
    fn test_calculate_btc_value() {
        // 0.1 units (8 dp) at a price of 50,000 (8 dp) is 5,000, reported at 6 dp.
        let amount: i64 = 10_000_000;
        let unit_price: i64 = 5_000_000_000_000;
        let value = calculate_value(
            amount,
            OracleDecimals::Eight,
            unit_price,
            OracleDecimals::Eight,
            OracleDecimals::Six,
        )
        .unwrap();
        assert_eq!(value, 5_000_000_000);
    }

    #[test]
    fn test_oracle_decimals_from_u8() {
        let cases = [
            (6u8, OracleDecimals::Six),
            (8, OracleDecimals::Eight),
            (18, OracleDecimals::Eighteen),
            (12, OracleDecimals::Custom(12)),
        ];
        for &(raw, expected) in &cases {
            assert_eq!(OracleDecimals::from(raw), expected);
        }
    }
}
507}