use std::collections::HashMap;
use solana_pubkey::Pubkey;
use pyra_margin::get_token_balance;
use pyra_types::{Cache, DriftUser, SpotMarket, Vault};
use crate::{RedisClient, RedisError, RedisKey, RedisResult};
/// Converts a raw token balance into a value in cents (hundredths of the price's unit).
///
/// `token_balance` is expressed in the token's smallest unit. It is scaled down by
/// `10^decimals`, multiplied by `price` and by 100, then rounded to the nearest integer
/// (`f64::round`, ties away from zero).
///
/// # Errors
/// Returns `RedisError::MathOverflow` if `decimals` does not fit in an `i32`, or if the rounded
/// result is not finite (e.g. a NaN or infinite `price`) or lies outside the `i64` range.
fn balance_to_cents(token_balance: i128, decimals: u32, price: f64) -> RedisResult<i64> {
    let exponent = i32::try_from(decimals).map_err(|_| RedisError::MathOverflow)?;
    let decimals_pow = 10_f64.powi(exponent);
    let value = (token_balance as f64) / decimals_pow * price * 100.0;
    let rounded = value.round();
    // `i64::MAX as f64` rounds up to 2^63, one past the largest i64, so the upper bound must be
    // exclusive; otherwise exactly 2^63 would pass and be silently saturated by the `as` cast.
    // `i64::MIN as f64` is exactly -2^63, so the lower bound stays inclusive.
    if rounded.is_finite() && rounded >= i64::MIN as f64 && rounded < i64::MAX as f64 {
        Ok(rounded as i64)
    } else {
        Err(RedisError::MathOverflow)
    }
}
/// Everything needed to value a single vault's Drift spot positions, as loaded by
/// `RedisClient::fetch_vault_position_data`.
pub struct VaultPositionData {
    /// Cached Drift user account for the vault.
    pub drift_user: Cache<DriftUser>,
    /// Spot markets keyed by market index. Only markets that were present in Redis and
    /// deserialized successfully are included.
    pub spot_markets: HashMap<u16, SpotMarket>,
    /// Prices keyed by market index, parsed from JSON numbers; missing or unparsable prices are
    /// omitted. Treated as a price per whole token when valuing positions.
    /// NOTE(review): currency unit is presumably USD — confirm.
    pub prices: HashMap<u16, f64>,
}
/// Bulk snapshot of every cached Drift user plus shared market data, as loaded by
/// `RedisClient::fetch_all_drift_positions`.
pub struct AllDriftPositionsData {
    /// `(vault address, cached Drift user)` pairs for every scanned drift-user entry whose JSON
    /// deserialized successfully.
    pub drift_users: Vec<(String, Cache<DriftUser>)>,
    /// Spot markets keyed by market index; missing or unparsable markets are omitted.
    pub spot_markets: HashMap<u16, SpotMarket>,
    /// Prices keyed by market index; missing or unparsable prices are omitted.
    pub prices: HashMap<u16, f64>,
    /// Vault address -> owner pubkey (string form). Empty unless owners were requested.
    pub vault_owners: HashMap<String, String>,
    /// Number of drift-user entries that were present but failed to deserialize.
    pub skipped_drift_users: usize,
}
impl RedisClient {
    /// Fetches one vault's cached Drift user together with the requested spot markets and prices
    /// in a single `MGET`.
    ///
    /// MGET key layout, in order: the drift-user key, one spot-market key per entry of
    /// `market_indices`, then one price key per entry of `market_indices`.
    ///
    /// Spot markets or prices that are missing or fail to deserialize are left out of the
    /// returned maps.
    ///
    /// # Errors
    /// - `RedisError::NotFound` if the drift-user key has no value.
    /// - A deserialization error if the drift-user JSON cannot be parsed.
    /// - Any error returned by `mget`.
    pub async fn fetch_vault_position_data(
        &self,
        vault_address: &str,
        market_indices: &[u16],
    ) -> RedisResult<VaultPositionData> {
        let drift_user_key = format!("{}:{vault_address}", RedisKey::DRIFT_USER_PREFIX);
        let mut keys = vec![drift_user_key];
        keys.extend(
            market_indices
                .iter()
                .map(|idx| RedisKey::drift_spot_market(*idx).to_string()),
        );
        keys.extend(market_indices.iter().map(|idx| RedisKey::price(*idx).to_string()));
        let values = self.mget(&keys).await?;

        let drift_user_raw = values
            .first()
            .and_then(|v| v.as_ref())
            .ok_or_else(|| RedisError::NotFound("DriftUser not found in Redis".into()))?;
        let drift_user: Cache<DriftUser> = serde_json::from_str(drift_user_raw)?;

        // Markets start right after the single drift-user slot.
        let (spot_markets, prices) = Self::parse_markets_and_prices(&values, 1, market_indices)?;
        Ok(VaultPositionData {
            drift_user,
            spot_markets,
            prices,
        })
    }

    /// Scans every cached Drift user and fetches them, optionally their vault owners, and the
    /// requested spot markets and prices in a single `MGET`.
    ///
    /// MGET key layout, in order: one key per drift user, then (only when
    /// `include_vault_owners`) one vault key per drift user, then one spot-market key per entry
    /// of `market_indices`, then one price key per entry of `market_indices`.
    ///
    /// Scanned keys without the `DRIFT_USER_PREFIX:` prefix are ignored, and duplicate keys are
    /// collapsed to their first occurrence. Drift users whose value is missing are dropped
    /// silently. Those whose JSON fails to parse are dropped and counted in
    /// `skipped_drift_users`. Vault owners are recorded only when the vault entry parses and its
    /// owner is exactly 32 bytes.
    ///
    /// # Errors
    /// Any error returned by `scan_keys` or `mget`, or `RedisError::MathOverflow` if an MGET
    /// offset overflows `usize`.
    pub async fn fetch_all_drift_positions(
        &self,
        market_indices: &[u16],
        include_vault_owners: bool,
    ) -> RedisResult<AllDriftPositionsData> {
        let drift_keys = self
            .scan_keys(&RedisKey::pattern(RedisKey::DRIFT_USER_PREFIX))
            .await?;
        let prefix_with_colon = format!("{}:", RedisKey::DRIFT_USER_PREFIX);

        // Keep each requested key paired with its vault address so that MGET offsets stay
        // aligned with `entries`. Keys without the expected prefix are dropped before they are
        // requested, and duplicates (Redis SCAN may return a key more than once) are collapsed.
        let mut seen = std::collections::HashSet::new();
        let entries: Vec<(&str, &str)> = drift_keys
            .iter()
            .filter_map(|key| {
                let addr = key.strip_prefix(prefix_with_colon.as_str())?;
                seen.insert(key.as_str()).then_some((key.as_str(), addr))
            })
            .collect();
        let num_drift = entries.len();

        let mut all_keys: Vec<String> = entries.iter().map(|(key, _)| (*key).to_string()).collect();
        if include_vault_owners {
            all_keys.extend(
                entries
                    .iter()
                    .map(|(_, addr)| format!("{}:{addr}", RedisKey::VAULT_PREFIX)),
            );
        }
        all_keys.extend(
            market_indices
                .iter()
                .map(|idx| RedisKey::drift_spot_market(*idx).to_string()),
        );
        all_keys.extend(market_indices.iter().map(|idx| RedisKey::price(*idx).to_string()));

        if all_keys.is_empty() {
            // Redis MGET requires at least one key; with nothing to look up, skip the round trip.
            return Ok(AllDriftPositionsData {
                drift_users: Vec::new(),
                spot_markets: HashMap::new(),
                prices: HashMap::new(),
                vault_owners: HashMap::new(),
                skipped_drift_users: 0,
            });
        }
        let values = self.mget(&all_keys).await?;

        let mut drift_users: Vec<(String, Cache<DriftUser>)> = Vec::with_capacity(num_drift);
        let mut skipped_drift_users: usize = 0;
        for (i, (_, vault_addr)) in entries.iter().enumerate() {
            if let Some(Some(raw)) = values.get(i) {
                match serde_json::from_str::<Cache<DriftUser>>(raw) {
                    Ok(du) => drift_users.push(((*vault_addr).to_string(), du)),
                    Err(_) => {
                        skipped_drift_users = skipped_drift_users.saturating_add(1);
                    }
                }
            }
        }

        let mut vault_owners: HashMap<String, String> = HashMap::new();
        if include_vault_owners {
            for (i, (_, vault_addr)) in entries.iter().enumerate() {
                let offset = num_drift.checked_add(i).ok_or(RedisError::MathOverflow)?;
                if let Some(Some(raw)) = values.get(offset) {
                    if let Ok(vault_cache) = serde_json::from_str::<Cache<Vault>>(raw) {
                        if let Ok(bytes) =
                            <[u8; 32]>::try_from(vault_cache.account.owner.as_slice())
                        {
                            let owner = Pubkey::from(bytes);
                            vault_owners.insert((*vault_addr).to_string(), owner.to_string());
                        }
                    }
                }
            }
        }

        // Vault keys were requested only when owners were asked for, one per entry.
        let vault_count = if include_vault_owners { num_drift } else { 0 };
        let market_base = num_drift
            .checked_add(vault_count)
            .ok_or(RedisError::MathOverflow)?;
        let (spot_markets, prices) =
            Self::parse_markets_and_prices(&values, market_base, market_indices)?;

        Ok(AllDriftPositionsData {
            drift_users,
            spot_markets,
            prices,
            vault_owners,
            skipped_drift_users,
        })
    }

    /// Parses the spot-market and price sections of an MGET result.
    ///
    /// Expects `values[market_base + i]` to hold the cached spot market for `market_indices[i]`,
    /// and `values[market_base + market_indices.len() + i]` to hold its price as a JSON number.
    /// Entries that are missing, out of range, or fail to deserialize are skipped.
    ///
    /// # Errors
    /// `RedisError::MathOverflow` if an offset overflows `usize`.
    fn parse_markets_and_prices<S: AsRef<str>>(
        values: &[Option<S>],
        market_base: usize,
        market_indices: &[u16],
    ) -> RedisResult<(HashMap<u16, SpotMarket>, HashMap<u16, f64>)> {
        let price_base = market_base
            .checked_add(market_indices.len())
            .ok_or(RedisError::MathOverflow)?;
        let mut spot_markets: HashMap<u16, SpotMarket> = HashMap::new();
        let mut prices: HashMap<u16, f64> = HashMap::new();
        for (i, idx) in market_indices.iter().enumerate() {
            let market_offset = market_base.checked_add(i).ok_or(RedisError::MathOverflow)?;
            if let Some(Some(raw)) = values.get(market_offset) {
                if let Ok(cache) = serde_json::from_str::<Cache<SpotMarket>>(raw.as_ref()) {
                    spot_markets.insert(*idx, cache.account);
                }
            }
            let price_offset = price_base.checked_add(i).ok_or(RedisError::MathOverflow)?;
            if let Some(Some(raw)) = values.get(price_offset) {
                if let Ok(price) = serde_json::from_str::<f64>(raw.as_ref()) {
                    prices.insert(*idx, price);
                }
            }
        }
        Ok((spot_markets, prices))
    }
}
pub fn compute_position_values(data: &VaultPositionData) -> RedisResult<Vec<i64>> {
compute_user_position_values(&data.drift_user.account, &data.spot_markets, &data.prices)
}
pub fn compute_asset_data(data: &VaultPositionData) -> RedisResult<Vec<(u16, i64, i64)>> {
compute_user_asset_data(&data.drift_user.account, &data.spot_markets, &data.prices)
}
/// Values each spot position of `drift_user` in cents, in position order.
///
/// A position is skipped when its scaled balance is zero, or when its market or price is absent
/// from the provided maps.
///
/// # Errors
/// Propagates the first failure from `get_token_balance` or `balance_to_cents`.
pub fn compute_user_position_values(
    drift_user: &DriftUser,
    spot_markets: &HashMap<u16, SpotMarket>,
    prices: &HashMap<u16, f64>,
) -> RedisResult<Vec<i64>> {
    drift_user
        .spot_positions
        .iter()
        .filter(|position| position.scaled_balance != 0)
        .filter_map(|position| {
            let market = spot_markets.get(&position.market_index)?;
            let price = *prices.get(&position.market_index)?;
            Some((position, market, price))
        })
        .map(|(position, market, price)| -> RedisResult<i64> {
            let token_balance = get_token_balance(position, market)?;
            balance_to_cents(token_balance, market.decimals, price)
        })
        .collect()
}
/// Returns `(market_index, token_balance, value_cents)` for each spot position of `drift_user`,
/// in position order.
///
/// A position is skipped when its scaled balance is zero, or when its market or price is absent
/// from the provided maps.
///
/// # Errors
/// Propagates failures from `get_token_balance` and `balance_to_cents`, and returns
/// `RedisError::MathOverflow` if a token balance does not fit in an `i64`.
pub fn compute_user_asset_data(
    drift_user: &DriftUser,
    spot_markets: &HashMap<u16, SpotMarket>,
    prices: &HashMap<u16, f64>,
) -> RedisResult<Vec<(u16, i64, i64)>> {
    let mut rows = Vec::new();
    let active = drift_user
        .spot_positions
        .iter()
        .filter(|position| position.scaled_balance != 0);
    for position in active {
        let market_index = position.market_index;
        let (Some(market), Some(&price)) =
            (spot_markets.get(&market_index), prices.get(&market_index))
        else {
            continue;
        };
        let raw_balance = get_token_balance(position, market)?;
        let balance = i64::try_from(raw_balance).map_err(|_| RedisError::MathOverflow)?;
        let cents = balance_to_cents(raw_balance, market.decimals, price)?;
        rows.push((market_index, balance, cents));
    }
    Ok(rows)
}
#[cfg(test)]
mod tests {
    use super::*;
    use pyra_types::{SpotBalanceType, SpotPosition};

    /// Builds a deposit position in `market_index` with the given scaled balance.
    fn make_spot_position(market_index: u16, scaled_balance: u64) -> SpotPosition {
        SpotPosition {
            market_index,
            scaled_balance,
            balance_type: SpotBalanceType::Deposit,
            ..Default::default()
        }
    }

    /// Builds a spot market whose cumulative interest equals `10^(19 - decimals)`, so the
    /// tests expect a scaled balance to map to the same raw token amount.
    fn make_spot_market(market_index: u16, decimals: u32) -> SpotMarket {
        let precision = 10u128.pow(19u32.saturating_sub(decimals));
        SpotMarket {
            pubkey: vec![],
            market_index,
            initial_asset_weight: 0,
            initial_liability_weight: 0,
            imf_factor: 0,
            scale_initial_asset_weight_start: 0,
            decimals,
            cumulative_deposit_interest: precision,
            cumulative_borrow_interest: precision,
            deposit_balance: 0,
            borrow_balance: 0,
            optimal_utilization: 0,
            optimal_borrow_rate: 0,
            max_borrow_rate: 0,
            min_borrow_rate: 0,
            insurance_fund: Default::default(),
            historical_oracle_data: Default::default(),
            oracle: None,
        }
    }

    /// Builds a Drift user holding exactly `positions`.
    fn make_drift_user(positions: Vec<SpotPosition>) -> DriftUser {
        DriftUser {
            authority: Default::default(),
            spot_positions: positions,
        }
    }

    #[test]
    fn compute_position_values_basic() {
        let drift_user = make_drift_user(vec![make_spot_position(0, 1_000_000)]);
        let spot_markets = HashMap::from([(0, make_spot_market(0, 6))]);
        let prices = HashMap::from([(0, 1.0)]);

        let values = compute_user_position_values(&drift_user, &spot_markets, &prices).unwrap();

        // 1 whole token (6 decimals) at price 1.0 -> 100 cents.
        assert_eq!(values, vec![100]);
    }

    #[test]
    fn compute_position_values_multiple_markets() {
        let drift_user = make_drift_user(vec![
            make_spot_position(0, 2_000_000),
            make_spot_position(1, 100_000_000),
        ]);
        let spot_markets = HashMap::from([
            (0, make_spot_market(0, 6)),
            (1, make_spot_market(1, 9)),
        ]);
        let prices = HashMap::from([(0, 1.0), (1, 150.0)]);

        let values = compute_user_position_values(&drift_user, &spot_markets, &prices).unwrap();

        // 2 tokens at 1.0 -> 200 cents; 0.1 token at 150.0 -> 1500 cents.
        assert_eq!(values, vec![200, 1500]);
    }

    #[test]
    fn compute_position_values_skips_zero_balance() {
        let drift_user = make_drift_user(vec![
            make_spot_position(0, 0),
            make_spot_position(1, 1_000_000),
        ]);
        let spot_markets = HashMap::from([
            (0, make_spot_market(0, 6)),
            (1, make_spot_market(1, 6)),
        ]);
        let prices = HashMap::from([(0, 1.0), (1, 1.0)]);

        let values = compute_user_position_values(&drift_user, &spot_markets, &prices).unwrap();

        assert_eq!(values, vec![100]);
    }

    #[test]
    fn compute_position_values_skips_missing_market() {
        let drift_user = make_drift_user(vec![make_spot_position(99, 1_000_000)]);
        let spot_markets = HashMap::new();
        let prices = HashMap::from([(99, 1.0)]);

        let values = compute_user_position_values(&drift_user, &spot_markets, &prices).unwrap();

        assert!(values.is_empty());
    }

    #[test]
    fn compute_position_values_skips_missing_price() {
        let drift_user = make_drift_user(vec![make_spot_position(0, 1_000_000)]);
        let spot_markets = HashMap::from([(0, make_spot_market(0, 6))]);
        let prices = HashMap::new();

        let values = compute_user_position_values(&drift_user, &spot_markets, &prices).unwrap();

        assert!(values.is_empty());
    }

    #[test]
    fn compute_asset_data_returns_tuples() {
        let drift_user = make_drift_user(vec![make_spot_position(0, 5_000_000)]);
        let spot_markets = HashMap::from([(0, make_spot_market(0, 6))]);
        let prices = HashMap::from([(0, 1.0)]);

        let data = compute_user_asset_data(&drift_user, &spot_markets, &prices).unwrap();

        // (market_index, token_balance, value_cents)
        assert_eq!(data, vec![(0, 5_000_000, 500)]);
    }

    #[test]
    fn compute_position_values_delegates_to_user_variant() {
        let vpd = VaultPositionData {
            drift_user: Cache {
                account: make_drift_user(vec![make_spot_position(0, 1_000_000)]),
                last_updated_slot: 12345,
            },
            spot_markets: HashMap::from([(0, make_spot_market(0, 6))]),
            prices: HashMap::from([(0, 1.0)]),
        };

        let values = compute_position_values(&vpd).unwrap();

        assert_eq!(values, vec![100]);
    }
}