drift_rs/math/leverage.rs

1use solana_sdk::pubkey::Pubkey;
2
3use super::{
4    account_list_builder::AccountsListBuilder,
5    constants::{AMM_RESERVE_PRECISION, BASE_PRECISION, MARGIN_PRECISION, PRICE_PRECISION},
6};
7use crate::{
8    accounts::PerpMarket,
9    ffi::{
10        calculate_margin_requirement_and_total_collateral_and_liability_info, MarginCalculation,
11        MarginContextMode,
12    },
13    types::accounts::User,
14    ContractType, DriftClient, MarginMode, MarginRequirementType, MarketId, PositionDirection,
15    SdkError, SdkResult,
16};
17
18pub fn get_leverage(client: &DriftClient, user: &User) -> SdkResult<u128> {
19    let mut builder = AccountsListBuilder::default();
20    let mut accounts = builder.try_build(client, user, &[])?;
21    let margin_calculation = calculate_margin_requirement_and_total_collateral_and_liability_info(
22        user,
23        &mut accounts,
24        MarginContextMode::StandardMaintenance,
25    )?;
26
27    let net_asset_value = calculate_net_asset_value(
28        margin_calculation.total_collateral,
29        margin_calculation.total_spot_liability_value,
30    );
31
32    if net_asset_value == i128::MIN {
33        return Err(SdkError::MathError(
34            "Net asset value is less than i128::MIN".to_string(),
35        ));
36    }
37
38    let total_liability_value = margin_calculation
39        .total_perp_liability_value
40        .checked_add(margin_calculation.total_spot_liability_value)
41        .expect("fits u128");
42
43    let leverage = calculate_leverage(total_liability_value, net_asset_value);
44
45    Ok(leverage)
46}
47
48pub fn get_spot_asset_value(client: &DriftClient, user: &User) -> SdkResult<i128> {
49    let mut builder = AccountsListBuilder::default();
50    let mut accounts = builder.try_build(client, user, &[])?;
51
52    let margin_calculation = calculate_margin_requirement_and_total_collateral_and_liability_info(
53        user,
54        &mut accounts,
55        MarginContextMode::StandardMaintenance,
56    )?;
57
58    Ok(margin_calculation.total_spot_asset_value
59        - margin_calculation.total_spot_liability_value as i128)
60}
61
/// Net asset value: `total_collateral - total_spot_liability_value`.
///
/// The subtraction is evaluated exactly over the full `u128` range of the liability. When the
/// true result is below `i128::MIN`, `i128::MIN` is returned as a sentinel, which callers check
/// for.
fn calculate_net_asset_value(total_collateral: i128, total_spot_liability_value: u128) -> i128 {
    total_collateral.saturating_sub_unsigned(total_spot_liability_value)
}
74
75fn calculate_leverage(total_liability_value: u128, net_asset_value: i128) -> u128 {
76    let sign: i128 = if net_asset_value < 0 { -1 } else { 1 };
77
78    let leverage = (total_liability_value as f64) / (net_asset_value.abs() as f64);
79
80    sign as u128 * (leverage * PRICE_PRECISION as f64) as u128
81}
82
83/// Provides margin calculation helpers for User accounts
84///
85/// sync, requires client is subscribed to necessary markets beforehand
86pub trait UserMargin {
87    /// Calculate user's max. trade size in USDC for a given market and direction
88    ///
89    /// * `user` - the user account
90    /// * `market` - the market to trade
91    /// * `trade_side` - the direction of the trade
92    ///
93    /// Returns max USDC trade size (PRICE_PRECISION)
94    fn max_trade_size(
95        &self,
96        user: &Pubkey,
97        market: MarketId,
98        trade_side: PositionDirection,
99    ) -> SdkResult<u64>;
100    fn calculate_perp_buying_power(
101        &self,
102        user: &User,
103        market: &PerpMarket,
104        oracle_price: i64,
105        collateral_buffer: u64,
106    ) -> SdkResult<u128>;
107    /// Calculate the user's live margin information
108    fn calculate_margin_info(&self, user: &User) -> SdkResult<MarginCalculation>;
109}
110
111impl UserMargin for DriftClient {
112    fn calculate_margin_info(&self, user: &User) -> SdkResult<MarginCalculation> {
113        let mut builder = AccountsListBuilder::default();
114        let mut accounts = builder.try_build(self, user, &[])?;
115        calculate_margin_requirement_and_total_collateral_and_liability_info(
116            user,
117            &mut accounts,
118            MarginContextMode::StandardMaintenance,
119        )
120    }
121    fn max_trade_size(
122        &self,
123        user: &Pubkey,
124        market: MarketId,
125        trade_side: PositionDirection,
126    ) -> SdkResult<u64> {
127        let oracle = self
128            .try_get_oracle_price_data_and_slot(market)
129            .ok_or(SdkError::NoMarketData(market))?;
130        let oracle_price = oracle.data.price;
131        let user_account = self.try_get_account::<User>(user)?;
132
133        if market.is_perp() {
134            let market_account = self.try_get_perp_market_account(market.index())?;
135
136            let position = user_account
137                .get_perp_position(market_account.market_index)
138                .map_err(|_| SdkError::NoMarketData(MarketId::perp(market_account.market_index)))?;
139            // add any position we have on the opposite side of the current trade
140            // because we can "flip" the size of this position without taking any extra leverage.
141            let is_reduce_only = position.base_asset_amount.is_negative() as u8 != trade_side as u8;
142            let opposite_side_liability_value = calculate_perp_liability_value(
143                position.base_asset_amount,
144                oracle_price,
145                market_account.contract_type == ContractType::Prediction,
146            );
147
148            let lp_buffer = ((oracle_price as u64 * market_account.amm.order_step_size)
149                / AMM_RESERVE_PRECISION as u64)
150                * position.lp_shares.max(1);
151
152            let max_position_size = self.calculate_perp_buying_power(
153                &user_account,
154                &market_account,
155                oracle_price,
156                lp_buffer,
157            )?;
158
159            Ok(max_position_size as u64 + opposite_side_liability_value * is_reduce_only as u64)
160        } else {
161            // TODO: implement for spot
162            Err(SdkError::Generic("spot market unimplemented".to_string()))
163        }
164    }
165    /// Calculate buying power = free collateral / initial margin ratio
166    ///
167    /// Returns buying power in `QUOTE_PRECISION` units
168    fn calculate_perp_buying_power(
169        &self,
170        user: &User,
171        market: &PerpMarket,
172        oracle_price: i64,
173        collateral_buffer: u64,
174    ) -> SdkResult<u128> {
175        let position = user
176            .get_perp_position(market.market_index)
177            .map_err(|_| SdkError::NoMarketData(MarketId::perp(market.market_index)))?;
178        let position_with_lp_settle =
179            position.simulate_settled_lp_position(market, oracle_price)?;
180
181        let worst_case_base_amount = position_with_lp_settle
182            .worst_case_base_asset_amount(oracle_price, market.contract_type)?;
183
184        let margin_info = self.calculate_margin_info(user)?;
185        let free_collateral = margin_info
186            .get_free_collateral()
187            .checked_sub(collateral_buffer as u128)
188            .ok_or(SdkError::MathError("underflow".to_string()))?;
189
190        let margin_ratio = market
191            .get_margin_ratio(
192                worst_case_base_amount.unsigned_abs(),
193                MarginRequirementType::Initial,
194                user.margin_mode == MarginMode::HighLeverage,
195            )
196            .expect("got margin ratio");
197        let margin_ratio = margin_ratio.max(user.max_margin_ratio);
198
199        Ok((free_collateral * MARGIN_PRECISION as u128) / margin_ratio as u128)
200    }
201}
202
203#[inline]
204pub fn calculate_perp_liability_value(
205    base_asset_amount: i64,
206    price: i64,
207    is_prediction_market: bool,
208) -> u64 {
209    let max_prediction_price = PRICE_PRECISION as i64;
210    let max_price =
211        max_prediction_price * base_asset_amount.is_negative() as i64 * is_prediction_market as i64;
212    (base_asset_amount * (max_price - price) / BASE_PRECISION as i64).unsigned_abs()
213}
214
#[cfg(test)]
mod tests {
    use super::calculate_perp_liability_value;
    use crate::math::constants::{BASE_PRECISION_I64, PRICE_PRECISION_I64};

    #[test]
    fn calculate_perp_liability_value_works() {
        // (base_asset_amount, price, is_prediction_market, expected) — values taken from TS sdk
        let cases: [(i64, i64, bool, u64); 4] = [
            (BASE_PRECISION_I64, 5 * PRICE_PRECISION_I64, false, 5_000_000),
            (-BASE_PRECISION_I64, 5 * PRICE_PRECISION_I64, false, 5_000_000),
            (-BASE_PRECISION_I64, 10_000, true, 990_000),
            (BASE_PRECISION_I64, 90_000, true, 90_000),
        ];

        for (base, price, is_prediction, expected) in cases {
            assert_eq!(
                calculate_perp_liability_value(base, price, is_prediction),
                expected,
                "base={base} price={price} prediction={is_prediction}"
            );
        }
    }
}
241
/// Live mainnet smoke tests; they need network access and are enabled by the `rpc_tests` feature.
#[cfg(feature = "rpc_tests")]
mod rpc_tests {
    use super::*;
    use crate::{
        utils::test_envs::{mainnet_endpoint, test_keypair},
        Context, RpcAccountProvider, Wallet,
    };

    #[tokio::test]
    async fn test_get_spot_market_value() {
        let wallet: Wallet = test_keypair().into();
        let pubkey = wallet.authority().clone();
        let drift_client = DriftClient::new(
            Context::MainNet,
            RpcAccountProvider::new(&mainnet_endpoint()),
            wallet,
        )
        .await
        .expect("drift client");
        drift_client.subscribe().await.expect("subscribe");

        // give the subscriptions time to receive account data
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;

        let mut user = crate::user::DriftUser::new(
            Wallet::derive_user_account(&pubkey, 0, &constants::PROGRAM_ID),
            drift_client.clone(),
        )
        .await
        .expect("drift user");
        user.subscribe().await.expect("subscribe");

        let spot_asset_value = get_spot_asset_value(&drift_client, &user.get_user_account())
            .expect("spot asset value");
        println!("spot_asset_value: {}", spot_asset_value);
    }

    #[tokio::test]
    async fn test_leverage() {
        let wallet: Wallet = test_keypair().into();
        let pubkey = wallet.authority().clone();
        let drift_client = DriftClient::new(
            Context::MainNet,
            // was `RpcAccountProvider::new & (mainnet_endpoint())`, which does not compile
            RpcAccountProvider::new(&mainnet_endpoint()),
            wallet,
        )
        .await
        .expect("drift client");
        drift_client.subscribe().await.expect("subscribe");

        // give the subscriptions time to receive account data
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;

        let mut user = crate::user::DriftUser::new(
            Wallet::derive_user_account(&pubkey, 0, &constants::PROGRAM_ID),
            drift_client.clone(),
        )
        .await
        .expect("drift user");
        user.subscribe().await.expect("subscribe");

        let leverage = get_leverage(&drift_client, &user.get_user_account()).expect("leverage");
        println!("leverage: {}", leverage);
    }
}