drift_rs/
account_map.rs

1//! Hybrid solana account map backed by Ws or RPC polling
2use std::{
3    sync::{Arc, Mutex},
4    time::Duration,
5};
6
7use anchor_lang::AccountDeserialize;
8use dashmap::DashMap;
9use drift_pubsub_client::PubsubClient;
10use log::debug;
11use solana_rpc_client::nonblocking::rpc_client::RpcClient;
12use solana_sdk::{clock::Slot, commitment_config::CommitmentConfig, pubkey::Pubkey};
13
14use crate::{
15    grpc::AccountUpdate, polled_account_subscriber::PolledAccountSubscriber, types::DataAndSlot,
16    websocket_account_subscriber::WebsocketAccountSubscriber, SdkResult, UnsubHandle,
17};
18
/// `log` target attached to every message this module emits.
const LOG_TARGET: &str = "accountmap";
20
21#[derive(Clone, Default)]
22pub struct AccountSlot {
23    raw: Vec<u8>,
24    slot: Slot,
25}
26
27/// Set of subscriptions to network accounts
28///
29/// Accounts are subscribed by either Ws or polling at fixed intervals
30pub struct AccountMap {
31    pubsub: Arc<PubsubClient>,
32    rpc: Arc<RpcClient>,
33    commitment: CommitmentConfig,
34    inner: Arc<DashMap<Pubkey, AccountSlot, ahash::RandomState>>,
35    subscriptions: Arc<DashMap<Pubkey, AccountSub<Subscribed>, ahash::RandomState>>,
36}
37
38impl AccountMap {
39    pub fn new(
40        pubsub: Arc<PubsubClient>,
41        rpc: Arc<RpcClient>,
42        commitment: CommitmentConfig,
43    ) -> Self {
44        Self {
45            pubsub,
46            rpc,
47            commitment,
48            inner: Arc::default(),
49            subscriptions: Arc::default(),
50        }
51    }
52    /// Subscribe account with Ws
53    ///
54    /// * `account` pubkey to subscribe
55    ///
56    pub async fn subscribe_account(&self, account: &Pubkey) -> SdkResult<()> {
57        if self.inner.contains_key(account) {
58            return Ok(());
59        }
60        debug!(target: LOG_TARGET, "subscribing: {account:?}");
61
62        let user = AccountSub::new(Arc::clone(&self.pubsub), self.commitment, *account);
63        let sub = user.subscribe(Arc::clone(&self.inner)).await?;
64        self.subscriptions.insert(*account, sub);
65
66        Ok(())
67    }
68    /// Subscribe account with RPC polling
69    ///
70    /// * `account` pubkey to subscribe
71    /// * `interval` to poll the account
72    ///
73    pub async fn subscribe_account_polled(
74        &self,
75        account: &Pubkey,
76        interval: Option<Duration>,
77    ) -> SdkResult<()> {
78        if self.inner.contains_key(account) {
79            return Ok(());
80        }
81        debug!(target: LOG_TARGET, "subscribing: {account:?} @ {interval:?}");
82
83        let user = AccountSub::polled(Arc::clone(&self.rpc), *account, interval);
84        let sub = user.subscribe(Arc::clone(&self.inner)).await?;
85        self.subscriptions.insert(*account, sub);
86
87        Ok(())
88    }
89    /// On account update callback for gRPC hook
90    pub(crate) fn on_account_fn(&self) -> impl Fn(&AccountUpdate) {
91        let accounts = Arc::clone(&self.inner);
92        let subscriptions = Arc::clone(&self.subscriptions);
93        move |update| {
94            accounts
95                .entry(update.pubkey)
96                .and_modify(|x| {
97                    x.slot = update.slot;
98                    x.raw.resize(update.data.len(), 0);
99                    x.raw.clone_from_slice(update.data);
100                    if update.lamports == 0 {
101                        accounts.remove(&update.pubkey);
102                    }
103                })
104                .or_insert({
105                    subscriptions.insert(
106                        update.pubkey,
107                        AccountSub {
108                            pubkey: update.pubkey,
109                            subscription: SubscriptionImpl::Grpc,
110                            state: Subscribed {
111                                unsub: Mutex::default(),
112                            },
113                        },
114                    );
115                    AccountSlot {
116                        slot: update.slot,
117                        raw: update.data.to_vec(),
118                    }
119                });
120        }
121    }
122    /// Unsubscribe user account
123    pub fn unsubscribe_account(&self, account: &Pubkey) {
124        if let Some((acc, sub)) = self.subscriptions.remove(account) {
125            debug!(target: LOG_TARGET, "unsubscribing: {acc:?}");
126            self.inner.remove(account);
127            let _ = sub.unsubscribe();
128        }
129    }
130    /// Return data of the given `account` as T, if it exists
131    pub fn account_data<T: AccountDeserialize>(&self, account: &Pubkey) -> Option<T> {
132        self.account_data_and_slot(account).map(|x| x.data)
133    }
134    /// Return data of the given `account` as T and slot, if it exists
135    pub fn account_data_and_slot<T: AccountDeserialize>(
136        &self,
137        account: &Pubkey,
138    ) -> Option<DataAndSlot<T>> {
139        self.inner.get(account).map(|x| DataAndSlot {
140            slot: x.slot,
141            data: T::try_deserialize_unchecked(&mut x.raw.as_slice()).expect("deserializes"),
142        })
143    }
144}
145
146struct Subscribed {
147    unsub: Mutex<Option<UnsubHandle>>,
148}
149struct Unsubscribed;
150
151/// A subscription to a solana account
152pub struct AccountSub<S> {
153    /// account pubkey
154    pub pubkey: Pubkey,
155    /// underlying subscription
156    subscription: SubscriptionImpl,
157    /// subscription state
158    state: S,
159}
160
161impl AccountSub<Unsubscribed> {
162    pub const SUBSCRIPTION_ID: &'static str = "account";
163
164    /// Create a new Ws account subscriber
165    pub fn new(pubsub: Arc<PubsubClient>, commitment: CommitmentConfig, pubkey: Pubkey) -> Self {
166        let subscription = WebsocketAccountSubscriber::new(pubsub, pubkey, commitment);
167
168        Self {
169            pubkey,
170            subscription: SubscriptionImpl::Ws(subscription),
171            state: Unsubscribed {},
172        }
173    }
174
175    /// Create a new polled account subscriber
176    pub fn polled(rpc: Arc<RpcClient>, pubkey: Pubkey, interval: Option<Duration>) -> Self {
177        let subscription =
178            PolledAccountSubscriber::new(pubkey, interval.unwrap_or(Duration::from_secs(5)), rpc);
179
180        Self {
181            pubkey,
182            subscription: SubscriptionImpl::Polled(subscription),
183            state: Unsubscribed {},
184        }
185    }
186
187    /// Start the subscriber task
188    pub async fn subscribe(
189        self,
190        accounts: Arc<DashMap<Pubkey, AccountSlot, ahash::RandomState>>,
191    ) -> SdkResult<AccountSub<Subscribed>> {
192        let unsub = match self.subscription {
193            SubscriptionImpl::Ws(ref ws) => {
194                let unsub = ws
195                    .subscribe(Self::SUBSCRIPTION_ID, true, move |update| {
196                        accounts
197                            .entry(update.pubkey)
198                            .and_modify(|x| {
199                                x.slot = update.slot;
200                                x.raw.clone_from(&update.data);
201                                if update.lamports == 0 {
202                                    accounts.remove(&update.pubkey);
203                                }
204                            })
205                            .or_insert(AccountSlot {
206                                raw: update.data.clone(),
207                                slot: update.slot,
208                            });
209                    })
210                    .await?;
211                Some(unsub)
212            }
213            SubscriptionImpl::Polled(ref poll) => {
214                let unsub = poll.subscribe(move |update| {
215                    accounts
216                        .entry(update.pubkey)
217                        .and_modify(|x| {
218                            x.slot = update.slot;
219                            x.raw.clone_from(&update.data);
220                            if update.lamports == 0 {
221                                accounts.remove(&update.pubkey);
222                            }
223                        })
224                        .or_insert(AccountSlot {
225                            raw: update.data.clone(),
226                            slot: update.slot,
227                        });
228                });
229                Some(unsub)
230            }
231            SubscriptionImpl::Grpc => None,
232        };
233
234        Ok(AccountSub {
235            pubkey: self.pubkey,
236            subscription: self.subscription,
237            state: Subscribed {
238                unsub: Mutex::new(unsub),
239            },
240        })
241    }
242}
243
244impl AccountSub<Subscribed> {
245    /// Stop the user subscriber task, if it exists
246    pub fn unsubscribe(self) -> AccountSub<Unsubscribed> {
247        let mut guard = self.state.unsub.lock().expect("acquire");
248        if let Some(unsub) = guard.take() {
249            if unsub.send(()).is_err() {
250                log::error!("couldn't unsubscribe");
251            }
252        }
253
254        AccountSub {
255            pubkey: self.pubkey,
256            subscription: self.subscription,
257            state: Unsubscribed,
258        }
259    }
260}
261
262enum SubscriptionImpl {
263    Ws(WebsocketAccountSubscriber),
264    Polled(PolledAccountSubscriber),
265    Grpc,
266}
267
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use solana_sdk::pubkey;

    use super::*;
    use crate::{
        accounts::User,
        constants::{state_account, DEFAULT_PUBKEY},
        types::accounts::State,
        utils::{get_ws_url, test_envs::mainnet_endpoint},
        Wallet,
    };

    /// Live mainnet check. Subscribes two users via Ws and the `State` account via
    /// polling, then verifies that unsubscribing one user clears only that user's data.
    #[tokio::test]
    async fn test_user_subscribe() {
        let _ = env_logger::try_init();

        let ws_url = get_ws_url(&mainnet_endpoint()).unwrap();
        let ws_client = PubsubClient::new(&ws_url).await.expect("ws connects");
        let rpc_client = RpcClient::new(mainnet_endpoint());
        let account_map = AccountMap::new(
            Arc::new(ws_client),
            Arc::new(rpc_client),
            CommitmentConfig::confirmed(),
        );

        let first_user = Wallet::derive_user_account(
            &pubkey!("DxoRJ4f5XRMvXU9SGuM4ZziBFUxbhB3ubur5sVZEvue2"),
            0,
        );
        let second_user = Wallet::derive_user_account(
            &pubkey!("Drift7AMLeq3FoKBMpT9wzqyMM3HVvvZFtsn81iSSkWV"),
            0,
        );

        let (ws_first, ws_second, polled_state) = tokio::join!(
            account_map.subscribe_account(&first_user),
            account_map.subscribe_account(&second_user),
            account_map.subscribe_account_polled(state_account(), Some(Duration::from_secs(2))),
        );
        assert!(ws_first.is_ok() && ws_second.is_ok() && polled_state.is_ok());

        let checks = tokio::spawn(async move {
            // give the subscriptions time to deliver their first updates
            tokio::time::sleep(Duration::from_secs(8)).await;

            let first = account_map.account_data::<User>(&first_user);
            assert!(first.is_some_and(|user| user.authority != DEFAULT_PUBKEY));

            account_map.unsubscribe_account(&first_user);
            assert!(account_map.account_data::<User>(&first_user).is_none());

            let second = account_map.account_data::<User>(&second_user);
            assert!(second.is_some_and(|user| user.authority != DEFAULT_PUBKEY));

            assert!(account_map
                .account_data::<State>(state_account())
                .is_some());
        });

        assert!(checks.await.is_ok());
    }
}