1use solana_sdk::pubkey::Pubkey;
2
3use super::{
4 account_list_builder::AccountsListBuilder,
5 constants::{AMM_RESERVE_PRECISION, BASE_PRECISION, MARGIN_PRECISION, PRICE_PRECISION},
6};
7use crate::{
8 accounts::PerpMarket,
9 ffi::{
10 calculate_margin_requirement_and_total_collateral_and_liability_info, MarginCalculation,
11 MarginContextMode,
12 },
13 types::accounts::User,
14 ContractType, DriftClient, MarginMode, MarginRequirementType, MarketId, PositionDirection,
15 SdkError, SdkResult,
16};
17
18pub fn get_leverage(client: &DriftClient, user: &User) -> SdkResult<u128> {
19 let mut builder = AccountsListBuilder::default();
20 let mut accounts = builder.try_build(client, user, &[])?;
21 let margin_calculation = calculate_margin_requirement_and_total_collateral_and_liability_info(
22 user,
23 &mut accounts,
24 MarginContextMode::StandardMaintenance,
25 )?;
26
27 let net_asset_value = calculate_net_asset_value(
28 margin_calculation.total_collateral,
29 margin_calculation.total_spot_liability_value,
30 );
31
32 if net_asset_value == i128::MIN {
33 return Err(SdkError::MathError(
34 "Net asset value is less than i128::MIN".to_string(),
35 ));
36 }
37
38 let total_liability_value = margin_calculation
39 .total_perp_liability_value
40 .checked_add(margin_calculation.total_spot_liability_value)
41 .expect("fits u128");
42
43 let leverage = calculate_leverage(total_liability_value, net_asset_value);
44
45 Ok(leverage)
46}
47
48pub fn get_spot_asset_value(client: &DriftClient, user: &User) -> SdkResult<i128> {
49 let mut builder = AccountsListBuilder::default();
50 let mut accounts = builder.try_build(client, user, &[])?;
51
52 let margin_calculation = calculate_margin_requirement_and_total_collateral_and_liability_info(
53 user,
54 &mut accounts,
55 MarginContextMode::StandardMaintenance,
56 )?;
57
58 Ok(margin_calculation.total_spot_asset_value
59 - margin_calculation.total_spot_liability_value as i128)
60}
61
/// Computes `total_collateral - total_spot_liability_value` without overflow.
///
/// The subtraction saturates at `i128::MIN`. Callers treat `i128::MIN` as a sentinel for
/// "net asset value out of range". The previous hand-rolled branch for liabilities above
/// `i128::MAX` produced wrong values and could overflow for negative collateral.
fn calculate_net_asset_value(total_collateral: i128, total_spot_liability_value: u128) -> i128 {
    total_collateral.saturating_sub_unsigned(total_spot_liability_value)
}
74
75fn calculate_leverage(total_liability_value: u128, net_asset_value: i128) -> u128 {
76 let sign: i128 = if net_asset_value < 0 { -1 } else { 1 };
77
78 let leverage = (total_liability_value as f64) / (net_asset_value.abs() as f64);
79
80 sign as u128 * (leverage * PRICE_PRECISION as f64) as u128
81}
82
83pub trait UserMargin {
87 fn max_trade_size(
95 &self,
96 user: &Pubkey,
97 market: MarketId,
98 trade_side: PositionDirection,
99 ) -> SdkResult<u64>;
100 fn calculate_perp_buying_power(
101 &self,
102 user: &User,
103 market: &PerpMarket,
104 oracle_price: i64,
105 collateral_buffer: u64,
106 ) -> SdkResult<u128>;
107 fn calculate_margin_info(&self, user: &User) -> SdkResult<MarginCalculation>;
109}
110
111impl UserMargin for DriftClient {
112 fn calculate_margin_info(&self, user: &User) -> SdkResult<MarginCalculation> {
113 let mut builder = AccountsListBuilder::default();
114 let mut accounts = builder.try_build(self, user, &[])?;
115 calculate_margin_requirement_and_total_collateral_and_liability_info(
116 user,
117 &mut accounts,
118 MarginContextMode::StandardMaintenance,
119 )
120 }
121 fn max_trade_size(
122 &self,
123 user: &Pubkey,
124 market: MarketId,
125 trade_side: PositionDirection,
126 ) -> SdkResult<u64> {
127 let oracle = self
128 .try_get_oracle_price_data_and_slot(market)
129 .ok_or(SdkError::NoMarketData(market))?;
130 let oracle_price = oracle.data.price;
131 let user_account = self.try_get_account::<User>(user)?;
132
133 if market.is_perp() {
134 let market_account = self.try_get_perp_market_account(market.index())?;
135
136 let position = user_account
137 .get_perp_position(market_account.market_index)
138 .map_err(|_| SdkError::NoMarketData(MarketId::perp(market_account.market_index)))?;
139 let is_reduce_only = position.base_asset_amount.is_negative() as u8 != trade_side as u8;
142 let opposite_side_liability_value = calculate_perp_liability_value(
143 position.base_asset_amount,
144 oracle_price,
145 market_account.contract_type == ContractType::Prediction,
146 );
147
148 let lp_buffer = ((oracle_price as u64 * market_account.amm.order_step_size)
149 / AMM_RESERVE_PRECISION as u64)
150 * position.lp_shares.max(1);
151
152 let max_position_size = self.calculate_perp_buying_power(
153 &user_account,
154 &market_account,
155 oracle_price,
156 lp_buffer,
157 )?;
158
159 Ok(max_position_size as u64 + opposite_side_liability_value * is_reduce_only as u64)
160 } else {
161 Err(SdkError::Generic("spot market unimplemented".to_string()))
163 }
164 }
165 fn calculate_perp_buying_power(
169 &self,
170 user: &User,
171 market: &PerpMarket,
172 oracle_price: i64,
173 collateral_buffer: u64,
174 ) -> SdkResult<u128> {
175 let position = user
176 .get_perp_position(market.market_index)
177 .map_err(|_| SdkError::NoMarketData(MarketId::perp(market.market_index)))?;
178 let position_with_lp_settle =
179 position.simulate_settled_lp_position(market, oracle_price)?;
180
181 let worst_case_base_amount = position_with_lp_settle
182 .worst_case_base_asset_amount(oracle_price, market.contract_type)?;
183
184 let margin_info = self.calculate_margin_info(user)?;
185 let free_collateral = margin_info
186 .get_free_collateral()
187 .checked_sub(collateral_buffer as u128)
188 .ok_or(SdkError::MathError("underflow".to_string()))?;
189
190 let margin_ratio = market
191 .get_margin_ratio(
192 worst_case_base_amount.unsigned_abs(),
193 MarginRequirementType::Initial,
194 user.margin_mode == MarginMode::HighLeverage,
195 )
196 .expect("got margin ratio");
197 let margin_ratio = margin_ratio.max(user.max_margin_ratio);
198
199 Ok((free_collateral * MARGIN_PRECISION as u128) / margin_ratio as u128)
200 }
201}
202
203#[inline]
204pub fn calculate_perp_liability_value(
205 base_asset_amount: i64,
206 price: i64,
207 is_prediction_market: bool,
208) -> u64 {
209 let max_prediction_price = PRICE_PRECISION as i64;
210 let max_price =
211 max_prediction_price * base_asset_amount.is_negative() as i64 * is_prediction_market as i64;
212 (base_asset_amount * (max_price - price) / BASE_PRECISION as i64).unsigned_abs()
213}
214
#[cfg(test)]
mod tests {
    use super::calculate_perp_liability_value;
    use crate::math::constants::{BASE_PRECISION_I64, PRICE_PRECISION_I64};

    /// Table-driven check of regular and prediction-market liability values for
    /// long and short positions.
    #[test]
    fn calculate_perp_liability_value_works() {
        // (base_asset_amount, price, is_prediction_market, expected liability)
        let cases = [
            (BASE_PRECISION_I64, 5 * PRICE_PRECISION_I64, false, 5_000_000),
            (-BASE_PRECISION_I64, 5 * PRICE_PRECISION_I64, false, 5_000_000),
            (-BASE_PRECISION_I64, 10_000, true, 990_000),
            (BASE_PRECISION_I64, 90_000, true, 90_000),
        ];

        for (base, price, is_prediction, expected) in cases {
            assert_eq!(
                calculate_perp_liability_value(base, price, is_prediction),
                expected
            );
        }
    }
}
241
#[cfg(feature = "rpc_tests")]
mod rpc_tests {
    use super::*;
    use crate::{
        user::DriftUser,
        utils::test_envs::{mainnet_endpoint, test_keypair},
        Context, RpcAccountProvider, Wallet,
    };

    /// Connects a mainnet `DriftClient` for the test keypair and subscribes its sub-account 0.
    ///
    /// Shared by both RPC tests. Previously each test duplicated this setup, and one copy
    /// contained the non-compiling `RpcAccountProvider::new & (...)`.
    async fn setup() -> (DriftClient, DriftUser) {
        let wallet: Wallet = test_keypair().into();
        let pubkey = wallet.authority().clone();
        let drift_client = DriftClient::new(
            Context::MainNet,
            RpcAccountProvider::new(&mainnet_endpoint()),
            wallet,
        )
        .await
        .expect("drift client");
        drift_client.subscribe().await.expect("subscribe");

        // Give the subscriptions time to populate before reading account state.
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;

        let mut user = DriftUser::new(
            Wallet::derive_user_account(&pubkey, 0, &crate::constants::PROGRAM_ID),
            drift_client.clone(),
        )
        .await
        .expect("drift user");
        user.subscribe().await.expect("subscribe");

        (drift_client, user)
    }

    #[tokio::test]
    async fn test_get_spot_market_value() {
        let (drift_client, user) = setup().await;

        let spot_asset_value = get_spot_asset_value(&drift_client, &user.get_user_account())
            .expect("spot asset value");
        println!("spot_asset_value: {}", spot_asset_value);
    }

    #[tokio::test]
    async fn test_leverage() {
        let (drift_client, user) = setup().await;

        let leverage = get_leverage(&drift_client, &user.get_user_account()).expect("leverage");
        println!("leverage: {}", leverage);
    }
}