Skip to main content

pyra_margin/
balance.rs

1use pyra_types::{SpotBalanceType, SpotMarket, SpotPosition};
2
3use crate::error::{MathError, MathResult};
4use crate::math::CheckedDivCeil;
5
6/// Calculate token balance from raw Drift position fields.
7///
8/// This is the low-level variant that accepts individual fields rather than
9/// typed structs, useful when the caller has a different SpotPosition/SpotMarket
10/// representation (e.g. Carbon decoder types).
11///
12/// Returns signed balance in base units (positive = deposit, negative = borrow).
13pub fn compute_token_balance(
14    scaled_balance: u64,
15    is_deposit: bool,
16    cumulative_deposit_interest: u128,
17    cumulative_borrow_interest: u128,
18    decimals: u32,
19) -> MathResult<i128> {
20    let precision_decrease = 10u128
21        .checked_pow(19u32.saturating_sub(decimals))
22        .ok_or(MathError::Overflow)?;
23
24    // u64 always fits in u128.
25    let balance = scaled_balance as u128;
26
27    let token_balance = if is_deposit {
28        let raw_balance = balance
29            .checked_mul(cumulative_deposit_interest)
30            .ok_or(MathError::Overflow)?;
31        // Safe: result of u128/u128 division is ≤ the dividend, which is bounded
32        // by scaled_balance (u64::MAX) × cumulative_interest — well within i128::MAX.
33        raw_balance
34            .checked_div(precision_decrease)
35            .ok_or(MathError::Overflow)? as i128
36    } else {
37        let raw_balance = balance
38            .checked_mul(cumulative_borrow_interest)
39            .ok_or(MathError::Overflow)?;
40        let balance_unsigned = raw_balance
41            .checked_div_ceil(precision_decrease)
42            .ok_or(MathError::Overflow)?;
43        // Safe: same bound as deposit path — division result fits in i128.
44        (balance_unsigned as i128).saturating_neg()
45    };
46
47    Ok(token_balance)
48}
49
50/// Calculate token balance from a Drift spot position and market data.
51/// Returns signed balance in base units (positive = deposit, negative = borrow).
52///
53/// This is a convenience wrapper around [`compute_token_balance`] for callers
54/// that already have `pyra_types` structs.
55pub fn get_token_balance(
56    spot_position: &SpotPosition,
57    spot_market: &SpotMarket,
58) -> MathResult<i128> {
59    let is_deposit = matches!(spot_position.balance_type, SpotBalanceType::Deposit);
60    compute_token_balance(
61        spot_position.scaled_balance,
62        is_deposit,
63        spot_market.cumulative_deposit_interest,
64        spot_market.cumulative_borrow_interest,
65        spot_market.decimals,
66    )
67}
68
69/// Calculate value of token balance in USDC base units.
70/// Takes a signed token balance and converts it to USDC value using the oracle price.
71pub fn calculate_value_usdc_base_units(
72    token_balance_base_units: i128,
73    price_usdc_base_units: u64,
74    token_decimals: u32,
75) -> MathResult<i128> {
76    let precision_decrease = 10u128
77        .checked_pow(token_decimals)
78        .ok_or(MathError::Overflow)?;
79
80    // u64 always fits in i128.
81    let value_usdc_base_units = token_balance_base_units
82        .checked_mul(price_usdc_base_units as i128)
83        .ok_or(MathError::Overflow)?
84        // u128 from checked_pow(decimals ≤ 19) always fits in i128.
85        .checked_div(precision_decrease as i128)
86        .ok_or(MathError::Overflow)?;
87
88    Ok(value_usdc_base_units)
89}
90
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects
)]
mod tests {
    use super::*;

    /// Cumulative interest at which a 6-decimal market maps scaled balances
    /// one-to-one onto token base units: 10^(19 - 6).
    const UNIT_INTEREST: u128 = 10u128.pow(19 - 6);

    /// Build a market with the given decimals and cumulative interest values;
    /// every other field is zeroed, defaulted, or absent.
    fn make_market(decimals: u32, deposit_interest: u128, borrow_interest: u128) -> SpotMarket {
        SpotMarket {
            pubkey: vec![],
            market_index: 0,
            initial_asset_weight: 0,
            initial_liability_weight: 0,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            decimals,
            cumulative_deposit_interest: deposit_interest,
            cumulative_borrow_interest: borrow_interest,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: Default::default(),
            historical_oracle_data: Default::default(),
            oracle: None,
        }
    }

    /// Build a position on the given side with otherwise default fields.
    fn make_position(scaled_balance: u64, balance_type: SpotBalanceType) -> SpotPosition {
        SpotPosition {
            scaled_balance,
            balance_type,
            ..Default::default()
        }
    }

    #[test]
    fn deposit_balance_basic() {
        let market = make_market(6, UNIT_INTEREST, UNIT_INTEREST);
        let position = make_position(1_000_000, SpotBalanceType::Deposit);

        let tokens = get_token_balance(&position, &market).unwrap();

        assert_eq!(tokens, 1_000_000); // 1 USDC
    }

    #[test]
    fn borrow_balance_is_negative() {
        let market = make_market(6, UNIT_INTEREST, UNIT_INTEREST);
        let position = make_position(500_000, SpotBalanceType::Borrow);

        let tokens = get_token_balance(&position, &market).unwrap();

        assert_eq!(tokens, -500_000);
    }

    #[test]
    fn deposit_with_interest() {
        // 10% accrued on the deposit side: unit interest * 1.1.
        let grown_interest = UNIT_INTEREST * 11 / 10;
        let market = make_market(6, grown_interest, UNIT_INTEREST);
        let position = make_position(1_000_000, SpotBalanceType::Deposit);

        let tokens = get_token_balance(&position, &market).unwrap();

        assert_eq!(tokens, 1_100_000); // 1.1 USDC
    }

    #[test]
    fn value_usdc_basic() {
        // 1 SOL (9 decimals) priced at $100, i.e. 100_000_000 USDC base units.
        let usdc = calculate_value_usdc_base_units(1_000_000_000, 100_000_000, 9).unwrap();
        assert_eq!(usdc, 100_000_000); // $100 in USDC base units
    }

    #[test]
    fn value_usdc_negative_balance() {
        let usdc = calculate_value_usdc_base_units(-1_000_000_000, 100_000_000, 9).unwrap();
        assert_eq!(usdc, -100_000_000);
    }

    #[test]
    fn value_usdc_usdc_token() {
        // 1 USDC (6 decimals) priced at $1, i.e. 1_000_000 USDC base units.
        let usdc = calculate_value_usdc_base_units(1_000_000, 1_000_000, 6).unwrap();
        assert_eq!(usdc, 1_000_000);
    }
}
189
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects
)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    /// Build a market for `decimals` whose cumulative deposit and borrow
    /// interest both equal 10^(19 - decimals), so scaled balances convert
    /// one-to-one into token base units.
    fn arb_market(decimals: u32) -> SpotMarket {
        let unit_interest = 10u128.pow(19u32.saturating_sub(decimals));
        SpotMarket {
            pubkey: vec![],
            market_index: 0,
            initial_asset_weight: 10_000,
            initial_liability_weight: 10_000,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            decimals,
            cumulative_deposit_interest: unit_interest,
            cumulative_borrow_interest: unit_interest,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: Default::default(),
            historical_oracle_data: Default::default(),
            oracle: None,
        }
    }

    /// Token balance of a position of `scaled_balance` on the given side in a
    /// unit-interest market with `decimals` decimals.
    fn balance_for(scaled_balance: u64, balance_type: SpotBalanceType, decimals: u32) -> i128 {
        let position = SpotPosition {
            scaled_balance,
            balance_type,
            ..Default::default()
        };
        get_token_balance(&position, &arb_market(decimals)).unwrap()
    }

    proptest! {
        #[test]
        fn deposit_balance_always_non_negative(
            scaled_balance in 0u64..=1_000_000_000_000u64,
            decimals in 6u32..=9u32,
        ) {
            let signed = balance_for(scaled_balance, SpotBalanceType::Deposit, decimals);
            prop_assert!(signed >= 0, "deposit balance {} should be >= 0", signed);
        }

        #[test]
        fn borrow_balance_always_non_positive(
            scaled_balance in 0u64..=1_000_000_000_000u64,
            decimals in 6u32..=9u32,
        ) {
            let signed = balance_for(scaled_balance, SpotBalanceType::Borrow, decimals);
            prop_assert!(signed <= 0, "borrow balance {} should be <= 0", signed);
        }

        #[test]
        fn value_preserves_sign(
            balance in -1_000_000_000_000i128..=1_000_000_000_000i128,
            price in 1u64..=1_000_000_000u64,
            decimals in 6u32..=9u32,
        ) {
            let value = calculate_value_usdc_base_units(balance, price, decimals).unwrap();
            // Truncating division may collapse a small non-zero balance to a
            // zero value, hence the non-strict comparisons.
            match balance.signum() {
                1 => {
                    prop_assert!(value >= 0, "positive balance should give non-negative value");
                }
                -1 => {
                    prop_assert!(value <= 0, "negative balance should give non-positive value");
                }
                _ => {
                    prop_assert_eq!(value, 0);
                }
            }
        }
    }
}