drift_rs/
usermap.rs

1use std::{
2    str::FromStr,
3    sync::{
4        atomic::{AtomicU64, Ordering},
5        Arc, Mutex,
6    },
7};
8
9use anchor_lang::{AccountDeserialize, AnchorDeserialize};
10use dashmap::DashMap;
11use serde_json::json;
12use solana_account_decoder_client_types::UiAccountEncoding;
13use solana_rpc_client::nonblocking::rpc_client::RpcClient;
14use solana_rpc_client_api::{
15    config::{RpcAccountInfoConfig, RpcProgramAccountsConfig},
16    filter::RpcFilterType,
17    request::RpcRequest,
18    response::{OptionalContext, RpcKeyedAccount},
19};
20use solana_sdk::{commitment_config::CommitmentConfig, pubkey::Pubkey};
21
22use crate::{
23    constants,
24    drift_idl::accounts::User,
25    memcmp::{get_non_idle_user_filter, get_user_filter},
26    utils::get_ws_url,
27    websocket_program_account_subscriber::{
28        WebsocketProgramAccountOptions, WebsocketProgramAccountSubscriber,
29    },
30    SdkResult, UnsubHandle,
31};
32
33/// Subscribes to the _all_ Drift users' account updates via Ws program subscribe
34pub struct GlobalUserMap {
35    subscribed: bool,
36    subscription: WebsocketProgramAccountSubscriber,
37    pub usermap: Arc<DashMap<String, User>>,
38    sync_lock: Option<Mutex<()>>,
39    latest_slot: Arc<AtomicU64>,
40    commitment: CommitmentConfig,
41    rpc: RpcClient,
42}
43
44impl GlobalUserMap {
45    pub const SUBSCRIPTION_ID: &'static str = "usermap";
46
47    pub fn new(
48        commitment: CommitmentConfig,
49        endpoint: String,
50        sync: bool,
51        additional_filters: Option<Vec<RpcFilterType>>,
52    ) -> Self {
53        let mut filters = vec![get_user_filter(), get_non_idle_user_filter()];
54        filters.extend(additional_filters.unwrap_or_default());
55        let options = WebsocketProgramAccountOptions {
56            filters,
57            commitment,
58            encoding: UiAccountEncoding::Base64Zstd,
59        };
60        let url = get_ws_url(&endpoint).unwrap();
61
62        let subscription = WebsocketProgramAccountSubscriber::new(url, options);
63
64        let usermap = Arc::new(DashMap::new());
65        let rpc = RpcClient::new_with_commitment(endpoint.clone(), commitment);
66        let sync_lock = if sync { Some(Mutex::new(())) } else { None };
67
68        Self {
69            subscribed: false,
70            subscription,
71            usermap,
72            sync_lock,
73            latest_slot: Arc::new(AtomicU64::new(0)),
74            commitment,
75            rpc,
76        }
77    }
78
79    pub async fn subscribe(&self) -> SdkResult<UnsubHandle> {
80        if self.sync_lock.is_some() {
81            self.sync().await?;
82        }
83
84        let unsub = self
85            .subscription
86            .subscribe::<User, _>(Self::SUBSCRIPTION_ID, {
87                let latest_slot = self.latest_slot.clone();
88                let user_map = self.usermap.clone();
89                move |update| {
90                    if update.data_and_slot.slot > latest_slot.load(Ordering::Relaxed) {
91                        latest_slot.store(update.data_and_slot.slot, Ordering::Relaxed);
92                    }
93                    user_map.insert(update.pubkey.clone(), update.data_and_slot.data);
94                }
95            });
96
97        Ok(unsub)
98    }
99
100    pub fn unsubscribe(self) -> SdkResult<()> {
101        if self.subscribed {
102            self.usermap.clear();
103            self.latest_slot.store(0, Ordering::Relaxed);
104        }
105        Ok(())
106    }
107
108    pub fn size(&self) -> usize {
109        self.usermap.len()
110    }
111
112    pub fn contains(&self, pubkey: &str) -> bool {
113        self.usermap.contains_key(pubkey)
114    }
115
116    pub fn get(&self, pubkey: &str) -> Option<User> {
117        self.usermap.get(pubkey).map(|user| *user.value())
118    }
119
120    pub async fn must_get(&self, pubkey: &str) -> SdkResult<User> {
121        if let Some(user) = self.get(pubkey) {
122            Ok(user)
123        } else {
124            let user_data = self
125                .rpc
126                .get_account_data(&Pubkey::from_str(pubkey).unwrap())
127                .await?;
128            let user = User::deserialize(&mut user_data.as_slice()).unwrap();
129            self.usermap.insert(pubkey.to_string(), user);
130            Ok(self.get(pubkey).unwrap())
131        }
132    }
133
134    #[allow(clippy::await_holding_lock)]
135    pub async fn sync(&self) -> SdkResult<()> {
136        let sync_lock = self.sync_lock.as_ref().expect("expected sync lock");
137
138        let _lock = match sync_lock.try_lock() {
139            Ok(lock) => lock,
140            Err(_) => return Ok(()),
141        };
142
143        let account_config = RpcAccountInfoConfig {
144            commitment: Some(self.commitment),
145            encoding: Some(self.subscription.options.encoding),
146            ..RpcAccountInfoConfig::default()
147        };
148
149        let gpa_config = RpcProgramAccountsConfig {
150            filters: Some(self.subscription.options.filters.clone()),
151            account_config,
152            with_context: Some(true),
153            sort_results: None,
154        };
155
156        let response = self
157            .rpc
158            .send::<OptionalContext<Vec<RpcKeyedAccount>>>(
159                RpcRequest::GetProgramAccounts,
160                json!([constants::PROGRAM_ID.to_string(), gpa_config]),
161            )
162            .await?;
163
164        if let OptionalContext::Context(accounts) = response {
165            for account in accounts.value {
166                let pubkey = account.pubkey;
167                let user_data = account.account.data.decode().expect("User data");
168                let data = User::try_deserialize_unchecked(&mut user_data.as_slice())
169                    .expect("User desrializes");
170                self.usermap.insert(pubkey, data);
171            }
172
173            self.latest_slot
174                .store(accounts.context.slot, Ordering::Relaxed);
175        }
176
177        Ok(())
178    }
179
180    pub fn get_latest_slot(&self) -> u64 {
181        self.latest_slot.load(Ordering::Relaxed)
182    }
183}
184
// Live-network test: only built for `cargo test` with the `rpc_tests` feature enabled.
#[cfg(all(test, feature = "rpc_tests"))]
mod tests {
    use solana_sdk::commitment_config::{CommitmentConfig, CommitmentLevel};

    use crate::{usermap::GlobalUserMap, utils::test_envs::mainnet_endpoint};

    #[tokio::test]
    async fn test_usermap() {
        let commitment = CommitmentConfig {
            commitment: CommitmentLevel::Processed,
        };

        let usermap = GlobalUserMap::new(commitment, mainnet_endpoint(), true, None);
        usermap.subscribe().await.unwrap();

        tokio::time::sleep(tokio::time::Duration::from_secs(30)).await;

        dbg!(usermap.size());
        assert!(usermap.size() > 50000);

        dbg!(usermap.get_latest_slot());

        // `unsubscribe` consumes the map, so keep a handle on the shared storage to
        // observe that it was cleared.
        let shared = usermap.usermap.clone();
        usermap.unsubscribe().unwrap();

        assert_eq!(shared.len(), 0);
    }
}