1use std::{
3 sync::{Arc, Mutex},
4 time::Duration,
5};
6
7use anchor_lang::AccountDeserialize;
8use dashmap::DashMap;
9use drift_pubsub_client::PubsubClient;
10use log::debug;
11use solana_rpc_client::nonblocking::rpc_client::RpcClient;
12use solana_sdk::{clock::Slot, commitment_config::CommitmentConfig, pubkey::Pubkey};
13
14use crate::{
15 grpc::AccountUpdate, polled_account_subscriber::PolledAccountSubscriber, types::DataAndSlot,
16 websocket_account_subscriber::WebsocketAccountSubscriber, SdkResult, UnsubHandle,
17};
18
/// Log target attached to the `debug!` records emitted by the account map.
const LOG_TARGET: &str = "accountmap";
20
21#[derive(Clone, Default)]
22pub struct AccountSlot {
23 raw: Vec<u8>,
24 slot: Slot,
25}
26
27pub struct AccountMap {
31 pubsub: Arc<PubsubClient>,
32 rpc: Arc<RpcClient>,
33 commitment: CommitmentConfig,
34 inner: Arc<DashMap<Pubkey, AccountSlot, ahash::RandomState>>,
35 subscriptions: Arc<DashMap<Pubkey, AccountSub<Subscribed>, ahash::RandomState>>,
36}
37
38impl AccountMap {
39 pub fn new(
40 pubsub: Arc<PubsubClient>,
41 rpc: Arc<RpcClient>,
42 commitment: CommitmentConfig,
43 ) -> Self {
44 Self {
45 pubsub,
46 rpc,
47 commitment,
48 inner: Arc::default(),
49 subscriptions: Arc::default(),
50 }
51 }
52 pub async fn subscribe_account(&self, account: &Pubkey) -> SdkResult<()> {
57 if self.inner.contains_key(account) {
58 return Ok(());
59 }
60 debug!(target: LOG_TARGET, "subscribing: {account:?}");
61
62 let user = AccountSub::new(Arc::clone(&self.pubsub), self.commitment, *account);
63 let sub = user.subscribe(Arc::clone(&self.inner)).await?;
64 self.subscriptions.insert(*account, sub);
65
66 Ok(())
67 }
68 pub async fn subscribe_account_polled(
74 &self,
75 account: &Pubkey,
76 interval: Option<Duration>,
77 ) -> SdkResult<()> {
78 if self.inner.contains_key(account) {
79 return Ok(());
80 }
81 debug!(target: LOG_TARGET, "subscribing: {account:?} @ {interval:?}");
82
83 let user = AccountSub::polled(Arc::clone(&self.rpc), *account, interval);
84 let sub = user.subscribe(Arc::clone(&self.inner)).await?;
85 self.subscriptions.insert(*account, sub);
86
87 Ok(())
88 }
89 pub(crate) fn on_account_fn(&self) -> impl Fn(&AccountUpdate) {
91 let accounts = Arc::clone(&self.inner);
92 let subscriptions = Arc::clone(&self.subscriptions);
93 move |update| {
94 accounts
95 .entry(update.pubkey)
96 .and_modify(|x| {
97 x.slot = update.slot;
98 x.raw.resize(update.data.len(), 0);
99 x.raw.clone_from_slice(update.data);
100 if update.lamports == 0 {
101 accounts.remove(&update.pubkey);
102 }
103 })
104 .or_insert({
105 subscriptions.insert(
106 update.pubkey,
107 AccountSub {
108 pubkey: update.pubkey,
109 subscription: SubscriptionImpl::Grpc,
110 state: Subscribed {
111 unsub: Mutex::default(),
112 },
113 },
114 );
115 AccountSlot {
116 slot: update.slot,
117 raw: update.data.to_vec(),
118 }
119 });
120 }
121 }
122 pub fn unsubscribe_account(&self, account: &Pubkey) {
124 if let Some((acc, sub)) = self.subscriptions.remove(account) {
125 debug!(target: LOG_TARGET, "unsubscribing: {acc:?}");
126 self.inner.remove(account);
127 let _ = sub.unsubscribe();
128 }
129 }
130 pub fn account_data<T: AccountDeserialize>(&self, account: &Pubkey) -> Option<T> {
132 self.account_data_and_slot(account).map(|x| x.data)
133 }
134 pub fn account_data_and_slot<T: AccountDeserialize>(
136 &self,
137 account: &Pubkey,
138 ) -> Option<DataAndSlot<T>> {
139 self.inner.get(account).map(|x| DataAndSlot {
140 slot: x.slot,
141 data: T::try_deserialize_unchecked(&mut x.raw.as_slice()).expect("deserializes"),
142 })
143 }
144}
145
146struct Subscribed {
147 unsub: Mutex<Option<UnsubHandle>>,
148}
/// Type-state of an [`AccountSub`] that is not (or no longer) subscribed.
struct Unsubscribed;
150
151pub struct AccountSub<S> {
153 pub pubkey: Pubkey,
155 subscription: SubscriptionImpl,
157 state: S,
159}
160
161impl AccountSub<Unsubscribed> {
162 pub const SUBSCRIPTION_ID: &'static str = "account";
163
164 pub fn new(pubsub: Arc<PubsubClient>, commitment: CommitmentConfig, pubkey: Pubkey) -> Self {
166 let subscription = WebsocketAccountSubscriber::new(pubsub, pubkey, commitment);
167
168 Self {
169 pubkey,
170 subscription: SubscriptionImpl::Ws(subscription),
171 state: Unsubscribed {},
172 }
173 }
174
175 pub fn polled(rpc: Arc<RpcClient>, pubkey: Pubkey, interval: Option<Duration>) -> Self {
177 let subscription =
178 PolledAccountSubscriber::new(pubkey, interval.unwrap_or(Duration::from_secs(5)), rpc);
179
180 Self {
181 pubkey,
182 subscription: SubscriptionImpl::Polled(subscription),
183 state: Unsubscribed {},
184 }
185 }
186
187 pub async fn subscribe(
189 self,
190 accounts: Arc<DashMap<Pubkey, AccountSlot, ahash::RandomState>>,
191 ) -> SdkResult<AccountSub<Subscribed>> {
192 let unsub = match self.subscription {
193 SubscriptionImpl::Ws(ref ws) => {
194 let unsub = ws
195 .subscribe(Self::SUBSCRIPTION_ID, true, move |update| {
196 accounts
197 .entry(update.pubkey)
198 .and_modify(|x| {
199 x.slot = update.slot;
200 x.raw.clone_from(&update.data);
201 if update.lamports == 0 {
202 accounts.remove(&update.pubkey);
203 }
204 })
205 .or_insert(AccountSlot {
206 raw: update.data.clone(),
207 slot: update.slot,
208 });
209 })
210 .await?;
211 Some(unsub)
212 }
213 SubscriptionImpl::Polled(ref poll) => {
214 let unsub = poll.subscribe(move |update| {
215 accounts
216 .entry(update.pubkey)
217 .and_modify(|x| {
218 x.slot = update.slot;
219 x.raw.clone_from(&update.data);
220 if update.lamports == 0 {
221 accounts.remove(&update.pubkey);
222 }
223 })
224 .or_insert(AccountSlot {
225 raw: update.data.clone(),
226 slot: update.slot,
227 });
228 });
229 Some(unsub)
230 }
231 SubscriptionImpl::Grpc => None,
232 };
233
234 Ok(AccountSub {
235 pubkey: self.pubkey,
236 subscription: self.subscription,
237 state: Subscribed {
238 unsub: Mutex::new(unsub),
239 },
240 })
241 }
242}
243
244impl AccountSub<Subscribed> {
245 pub fn unsubscribe(self) -> AccountSub<Unsubscribed> {
247 let mut guard = self.state.unsub.lock().expect("acquire");
248 if let Some(unsub) = guard.take() {
249 if unsub.send(()).is_err() {
250 log::error!("couldn't unsubscribe");
251 }
252 }
253
254 AccountSub {
255 pubkey: self.pubkey,
256 subscription: self.subscription,
257 state: Unsubscribed,
258 }
259 }
260}
261
262enum SubscriptionImpl {
263 Ws(WebsocketAccountSubscriber),
264 Polled(PolledAccountSubscriber),
265 Grpc,
266}
267
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use solana_sdk::pubkey;

    use super::*;
    use crate::{
        accounts::User,
        constants::{state_account, DEFAULT_PUBKEY},
        types::accounts::State,
        utils::{get_ws_url, test_envs::mainnet_endpoint},
        Wallet,
    };

    /// Integration test against the mainnet endpoint: subscribes two user
    /// accounts over websocket and the state account by polling, then checks
    /// cached reads and that unsubscribing evicts only the targeted account.
    #[tokio::test]
    async fn test_user_subscribe() {
        let _ = env_logger::try_init();

        let ws_url = get_ws_url(&mainnet_endpoint()).unwrap();
        let pubsub = Arc::new(PubsubClient::new(&ws_url).await.expect("ws connects"));
        let rpc = Arc::new(RpcClient::new(mainnet_endpoint()));
        let map = AccountMap::new(pubsub, rpc, CommitmentConfig::confirmed());

        let user_1 = Wallet::derive_user_account(
            &pubkey!("DxoRJ4f5XRMvXU9SGuM4ZziBFUxbhB3ubur5sVZEvue2"),
            0,
        );
        let user_2 = Wallet::derive_user_account(
            &pubkey!("Drift7AMLeq3FoKBMpT9wzqyMM3HVvvZFtsn81iSSkWV"),
            0,
        );

        // Start all three subscriptions concurrently.
        let (ws_1, ws_2, polled) = tokio::join!(
            map.subscribe_account(&user_1),
            map.subscribe_account(&user_2),
            map.subscribe_account_polled(state_account(), Some(Duration::from_secs(2))),
        );
        assert!(ws_1.is_ok());
        assert!(ws_2.is_ok());
        assert!(polled.is_ok());

        let checks = tokio::spawn(async move {
            // Give the subscriptions time to deliver data.
            tokio::time::sleep(Duration::from_secs(8)).await;

            let first = map.account_data::<User>(&user_1);
            assert!(first.is_some_and(|x| x.authority != DEFAULT_PUBKEY));

            // Unsubscribing must drop the cached data for that account only.
            map.unsubscribe_account(&user_1);
            let first = map.account_data::<User>(&user_1);
            assert!(first.is_none());

            let second = map.account_data::<User>(&user_2);
            assert!(second.is_some_and(|x| x.authority != DEFAULT_PUBKEY));

            let state = map.account_data::<State>(state_account());
            assert!(state.is_some());
        });

        assert!(checks.await.is_ok());
    }
}