1use std::collections::HashMap;
7
8use pyra_tokens::AssetId;
9
10use pyra_margin::get_token_balance;
11use pyra_types::{Cache, DriftUser, SpotMarket, Vault};
12
13use crate::{RedisClient, RedisError, RedisKey, RedisResult};
14
15fn balance_to_value_cents(token_balance: i128, decimals: u32, price: f64) -> RedisResult<i64> {
19 let decimals_pow = 10_f64.powi(i32::try_from(decimals).map_err(|_| RedisError::MathOverflow)?);
20 let value = (token_balance as f64) / decimals_pow * price * 100.0;
21 let rounded = value.round();
22 if rounded.is_finite() && rounded >= i64::MIN as f64 && rounded <= i64::MAX as f64 {
23 Ok(rounded as i64)
24 } else {
25 Err(RedisError::MathOverflow)
26 }
27}
28
29pub struct VaultPositionData {
33 pub drift_user: Cache<DriftUser>,
35 pub spot_markets: HashMap<AssetId, SpotMarket>,
37 pub prices: HashMap<AssetId, f64>,
39}
40
41pub struct AllDriftPositionsData {
43 pub drift_users: Vec<(String, Cache<DriftUser>)>,
45 pub spot_markets: HashMap<AssetId, SpotMarket>,
47 pub prices: HashMap<AssetId, f64>,
49 pub vault_owners: HashMap<String, String>,
51 pub skipped_drift_users: usize,
54}
55
56impl RedisClient {
59 pub async fn fetch_vault_position_data(
66 &self,
67 vault_address: &str,
68 asset_ids: &[AssetId],
69 ) -> RedisResult<VaultPositionData> {
70 let drift_user_key = format!("{}:{vault_address}", RedisKey::DRIFT_USER_PREFIX);
71 let mut keys = vec![drift_user_key];
72 for &id in asset_ids {
73 keys.push(RedisKey::drift_spot_market(id).to_string());
74 }
75 for &id in asset_ids {
76 keys.push(RedisKey::price(id).to_string());
77 }
78
79 let values = self.mget(&keys).await?;
80
81 let drift_user_raw = values
82 .first()
83 .and_then(|v| v.as_ref())
84 .ok_or_else(|| RedisError::NotFound("DriftUser not found in Redis".into()))?;
85 let drift_user: Cache<DriftUser> = serde_json::from_str(drift_user_raw)?;
86
87 let num_markets = asset_ids.len();
88 let mut spot_markets: HashMap<AssetId, SpotMarket> = HashMap::new();
89 let mut prices: HashMap<AssetId, f64> = HashMap::new();
90
91 for (i, &id) in asset_ids.iter().enumerate() {
92 if let Some(Some(raw)) =
93 values.get(1usize.checked_add(i).ok_or(RedisError::MathOverflow)?)
94 {
95 if let Ok(cache) = serde_json::from_str::<Cache<SpotMarket>>(raw) {
96 spot_markets.insert(id, cache.account);
97 }
98 }
99 if let Some(Some(raw)) = values.get(
100 1usize
101 .checked_add(num_markets)
102 .ok_or(RedisError::MathOverflow)?
103 .checked_add(i)
104 .ok_or(RedisError::MathOverflow)?,
105 ) {
106 if let Ok(price) = serde_json::from_str::<f64>(raw) {
107 prices.insert(id, price);
108 }
109 }
110 }
111
112 Ok(VaultPositionData {
113 drift_user,
114 spot_markets,
115 prices,
116 })
117 }
118
119 pub async fn fetch_all_drift_positions(
127 &self,
128 asset_ids: &[AssetId],
129 include_vault_owners: bool,
130 ) -> RedisResult<AllDriftPositionsData> {
131 let drift_keys = self
133 .scan_keys(&RedisKey::pattern(RedisKey::DRIFT_USER_PREFIX))
134 .await?;
135
136 let prefix_with_colon = format!("{}:", RedisKey::DRIFT_USER_PREFIX);
138 let vault_addresses: Vec<&str> = drift_keys
139 .iter()
140 .filter_map(|k| k.strip_prefix(prefix_with_colon.as_str()))
141 .collect();
142
143 let num_drift = drift_keys.len();
146 let mut all_keys: Vec<String> = drift_keys.clone();
147
148 if include_vault_owners {
149 for vault_addr in &vault_addresses {
150 all_keys.push(format!("{}:{vault_addr}", RedisKey::VAULT_PREFIX));
151 }
152 }
153 for &id in asset_ids {
154 all_keys.push(RedisKey::drift_spot_market(id).to_string());
155 }
156 for &id in asset_ids {
157 all_keys.push(RedisKey::price(id).to_string());
158 }
159
160 let values = self.mget(&all_keys).await?;
161
162 let mut drift_users: Vec<(String, Cache<DriftUser>)> = Vec::new();
164 let mut skipped_drift_users: usize = 0;
165 for (i, vault_addr) in vault_addresses.iter().enumerate() {
166 if let Some(Some(raw)) = values.get(i) {
167 match serde_json::from_str::<Cache<DriftUser>>(raw) {
168 Ok(du) => drift_users.push(((*vault_addr).to_string(), du)),
169 Err(_) => {
170 skipped_drift_users = skipped_drift_users.saturating_add(1);
171 }
172 }
173 }
174 }
175
176 let mut vault_owners: HashMap<String, String> = HashMap::new();
178 if include_vault_owners {
179 for (i, vault_addr) in vault_addresses.iter().enumerate() {
180 let offset = num_drift.checked_add(i).ok_or(RedisError::MathOverflow)?;
181 if let Some(Some(raw)) = values.get(offset) {
182 if let Ok(vault_cache) = serde_json::from_str::<Cache<Vault>>(raw) {
183 vault_owners.insert(
184 (*vault_addr).to_string(),
185 vault_cache.account.owner.to_string(),
186 );
187 }
188 }
189 }
190 }
191
192 let vault_count = if include_vault_owners { num_drift } else { 0 };
194 let market_base = num_drift
195 .checked_add(vault_count)
196 .ok_or(RedisError::MathOverflow)?;
197 let num_markets = asset_ids.len();
198
199 let mut spot_markets: HashMap<AssetId, SpotMarket> = HashMap::new();
200 let mut prices: HashMap<AssetId, f64> = HashMap::new();
201
202 for (i, &id) in asset_ids.iter().enumerate() {
203 let market_offset = market_base.checked_add(i).ok_or(RedisError::MathOverflow)?;
204 if let Some(Some(raw)) = values.get(market_offset) {
205 if let Ok(cache) = serde_json::from_str::<Cache<SpotMarket>>(raw) {
206 spot_markets.insert(id, cache.account);
207 }
208 }
209
210 let price_offset = market_base
211 .checked_add(num_markets)
212 .ok_or(RedisError::MathOverflow)?
213 .checked_add(i)
214 .ok_or(RedisError::MathOverflow)?;
215 if let Some(Some(raw)) = values.get(price_offset) {
216 if let Ok(price) = serde_json::from_str::<f64>(raw) {
217 prices.insert(id, price);
218 }
219 }
220 }
221
222 Ok(AllDriftPositionsData {
223 drift_users,
224 spot_markets,
225 prices,
226 vault_owners,
227 skipped_drift_users,
228 })
229 }
230}
231
232pub fn compute_position_values(data: &VaultPositionData) -> RedisResult<Vec<i64>> {
238 compute_user_position_values(&data.drift_user.account, &data.spot_markets, &data.prices)
239}
240
241pub fn compute_asset_data(data: &VaultPositionData) -> RedisResult<Vec<(AssetId, i64, i64)>> {
243 compute_user_asset_data(&data.drift_user.account, &data.spot_markets, &data.prices)
244}
245
246pub fn compute_user_position_values(
249 drift_user: &DriftUser,
250 spot_markets: &HashMap<AssetId, SpotMarket>,
251 prices: &HashMap<AssetId, f64>,
252) -> RedisResult<Vec<i64>> {
253 let mut results = Vec::new();
254 for position in &drift_user.spot_positions {
255 if position.scaled_balance == 0 {
256 continue;
257 }
258 let Some(token) = pyra_tokens::Token::find_by_drift_market_index(position.market_index)
259 else {
260 continue;
261 };
262 let asset_id = token.asset_id;
263 let Some(market) = spot_markets.get(&asset_id) else {
264 continue;
265 };
266 let Some(&price) = prices.get(&asset_id) else {
267 continue;
268 };
269 let token_balance = get_token_balance(position, market)?;
270 let value_cents = balance_to_value_cents(token_balance, market.decimals, price)?;
271 results.push(value_cents);
272 }
273 Ok(results)
274}
275
276pub fn compute_user_asset_data(
280 drift_user: &DriftUser,
281 spot_markets: &HashMap<AssetId, SpotMarket>,
282 prices: &HashMap<AssetId, f64>,
283) -> RedisResult<Vec<(AssetId, i64, i64)>> {
284 let mut results = Vec::new();
285 for position in &drift_user.spot_positions {
286 if position.scaled_balance == 0 {
287 continue;
288 }
289 let Some(token) = pyra_tokens::Token::find_by_drift_market_index(position.market_index)
290 else {
291 continue;
292 };
293 let asset_id = token.asset_id;
294 let Some(market) = spot_markets.get(&asset_id) else {
295 continue;
296 };
297 let Some(&price) = prices.get(&asset_id) else {
298 continue;
299 };
300 let token_balance_i128 = get_token_balance(position, market)?;
301 let token_balance =
302 i64::try_from(token_balance_i128).map_err(|_| RedisError::MathOverflow)?;
303 let value_cents = balance_to_value_cents(token_balance_i128, market.decimals, price)?;
304 results.push((asset_id, token_balance, value_cents));
305 }
306 Ok(results)
307}
308
#[cfg(test)]
#[allow(
    clippy::allow_attributes,
    clippy::allow_attributes_without_reason,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::arithmetic_side_effects
)]
mod tests {
    use super::*;
    use pyra_types::SpotBalanceType;

    /// Builds a deposit spot position in `market_index` with the given scaled balance.
    fn make_spot_position(market_index: u16, scaled_balance: u64) -> pyra_types::SpotPosition {
        pyra_types::SpotPosition {
            market_index,
            scaled_balance,
            balance_type: SpotBalanceType::Deposit,
            ..Default::default()
        }
    }

    /// Builds a spot market whose cumulative interest indices equal `10^(19 - decimals)`.
    /// The tests below rely on this setup yielding token balance == scaled balance.
    fn make_spot_market(market_index: u16, decimals: u32) -> SpotMarket {
        let precision = 10u128.pow(19u32.saturating_sub(decimals));
        SpotMarket {
            pubkey: vec![],
            market_index,
            initial_asset_weight: 0,
            initial_liability_weight: 0,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            decimals,
            cumulative_deposit_interest: precision,
            cumulative_borrow_interest: precision,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: Default::default(),
            historical_oracle_data: Default::default(),
            oracle: None,
        }
    }

    /// Builds a Drift user with a default authority and the given spot positions.
    fn make_drift_user(positions: Vec<pyra_types::SpotPosition>) -> DriftUser {
        DriftUser {
            authority: Default::default(),
            spot_positions: positions,
        }
    }

    #[test]
    fn compute_position_values_basic() {
        let drift_user = make_drift_user(vec![make_spot_position(0, 1_000_000)]);
        let mut spot_markets = HashMap::new();
        spot_markets.insert(AssetId::new(0).unwrap(), make_spot_market(0, 6));
        let mut prices = HashMap::new();
        prices.insert(AssetId::new(0).unwrap(), 1.0);

        let values = compute_user_position_values(&drift_user, &spot_markets, &prices).unwrap();
        assert_eq!(values.len(), 1);
        // 1 token (6 decimals) at price 1.0 -> 100 cents.
        assert_eq!(values[0], 100);
    }

    #[test]
    fn compute_position_values_multiple_markets() {
        let drift_user = make_drift_user(vec![
            make_spot_position(0, 2_000_000),   // 2 tokens at 6 decimals
            make_spot_position(1, 100_000_000), // 0.1 tokens at 9 decimals
        ]);
        let mut spot_markets = HashMap::new();
        spot_markets.insert(AssetId::new(0).unwrap(), make_spot_market(0, 6));
        spot_markets.insert(AssetId::new(1).unwrap(), make_spot_market(1, 9));
        let mut prices = HashMap::new();
        prices.insert(AssetId::new(0).unwrap(), 1.0);
        prices.insert(AssetId::new(1).unwrap(), 150.0);

        let values = compute_user_position_values(&drift_user, &spot_markets, &prices).unwrap();
        assert_eq!(values.len(), 2);
        assert_eq!(values[0], 200); // 2 * 1.0 -> 200 cents
        assert_eq!(values[1], 1500); // 0.1 * 150.0 -> 1500 cents
    }

    #[test]
    fn compute_position_values_skips_zero_balance() {
        let drift_user = make_drift_user(vec![
            make_spot_position(0, 0),         // zero balance: skipped
            make_spot_position(1, 1_000_000), // 1 token at 6 decimals
        ]);
        let mut spot_markets = HashMap::new();
        spot_markets.insert(AssetId::new(0).unwrap(), make_spot_market(0, 6));
        spot_markets.insert(AssetId::new(1).unwrap(), make_spot_market(1, 6));
        let mut prices = HashMap::new();
        prices.insert(AssetId::new(0).unwrap(), 1.0);
        prices.insert(AssetId::new(1).unwrap(), 1.0);

        let values = compute_user_position_values(&drift_user, &spot_markets, &prices).unwrap();
        assert_eq!(values.len(), 1);
        assert_eq!(values[0], 100);
    }

    #[test]
    fn compute_position_values_skips_missing_market() {
        let drift_user = make_drift_user(vec![make_spot_position(99, 1_000_000)]);
        let spot_markets = HashMap::new();
        let prices = HashMap::new();

        let values = compute_user_position_values(&drift_user, &spot_markets, &prices).unwrap();
        assert!(values.is_empty());
    }

    #[test]
    fn compute_position_values_skips_missing_price() {
        let drift_user = make_drift_user(vec![make_spot_position(0, 1_000_000)]);
        let mut spot_markets = HashMap::new();
        spot_markets.insert(AssetId::new(0).unwrap(), make_spot_market(0, 6));
        let prices = HashMap::new();

        let values = compute_user_position_values(&drift_user, &spot_markets, &prices).unwrap();
        assert!(values.is_empty());
    }

    #[test]
    fn compute_asset_data_returns_tuples() {
        let drift_user = make_drift_user(vec![make_spot_position(0, 5_000_000)]);
        let mut spot_markets = HashMap::new();
        spot_markets.insert(AssetId::new(0).unwrap(), make_spot_market(0, 6));
        let mut prices = HashMap::new();
        prices.insert(AssetId::new(0).unwrap(), 1.0);

        let data = compute_user_asset_data(&drift_user, &spot_markets, &prices).unwrap();
        assert_eq!(data.len(), 1);
        let (asset_id, token_balance, value_cents) = data[0];
        assert_eq!(asset_id, AssetId::new(0).unwrap());
        assert_eq!(token_balance, 5_000_000);
        assert_eq!(value_cents, 500);
    }

    #[test]
    fn compute_position_values_delegates_to_user_variant() {
        let drift_user = Cache {
            account: make_drift_user(vec![make_spot_position(0, 1_000_000)]),
            last_updated_slot: 12345,
        };
        let mut spot_markets = HashMap::new();
        spot_markets.insert(AssetId::new(0).unwrap(), make_spot_market(0, 6));
        let mut prices = HashMap::new();
        prices.insert(AssetId::new(0).unwrap(), 1.0);

        let vpd = VaultPositionData {
            drift_user,
            spot_markets,
            prices,
        };
        let values = compute_position_values(&vpd).unwrap();
        assert_eq!(values.len(), 1);
        assert_eq!(values[0], 100);
    }

    #[test]
    fn compute_asset_data_delegates_to_user_variant() {
        let vpd = VaultPositionData {
            drift_user: Cache {
                account: make_drift_user(vec![make_spot_position(0, 3_000_000)]),
                last_updated_slot: 1,
            },
            spot_markets: HashMap::from([(AssetId::new(0).unwrap(), make_spot_market(0, 6))]),
            prices: HashMap::from([(AssetId::new(0).unwrap(), 2.0)]),
        };

        let data = compute_asset_data(&vpd).unwrap();
        // 3 tokens (6 decimals) at price 2.0 -> 600 cents.
        assert_eq!(data, vec![(AssetId::new(0).unwrap(), 3_000_000, 600)]);
    }

    #[test]
    fn balance_to_value_cents_rounds_to_nearest_cent() {
        assert_eq!(balance_to_value_cents(1_000_000, 6, 1.0).unwrap(), 100);
        assert_eq!(balance_to_value_cents(1_234_567, 6, 1.0).unwrap(), 123);
        assert_eq!(balance_to_value_cents(1_236_000, 6, 1.0).unwrap(), 124);
        assert_eq!(balance_to_value_cents(-2_500_000, 6, 2.0).unwrap(), -500);
    }

    #[test]
    fn balance_to_value_cents_rejects_non_finite_and_out_of_range() {
        assert!(matches!(
            balance_to_value_cents(1_000_000, 6, f64::NAN),
            Err(RedisError::MathOverflow)
        ));
        assert!(matches!(
            balance_to_value_cents(1_000_000, 6, f64::INFINITY),
            Err(RedisError::MathOverflow)
        ));
        assert!(matches!(
            balance_to_value_cents(i128::MAX, 0, 1.0),
            Err(RedisError::MathOverflow)
        ));
    }
}