drift_rs/
websocket_program_account_subscriber.rs

1use std::time::Instant;
2
3use anchor_lang::AnchorDeserialize;
4use drift_pubsub_client::PubsubClient;
5use futures_util::StreamExt;
6use log::warn;
7use solana_account_decoder_client_types::UiAccountEncoding;
8use solana_rpc_client_api::{
9    config::{RpcAccountInfoConfig, RpcProgramAccountsConfig},
10    filter::RpcFilterType,
11};
12use solana_sdk::commitment_config::CommitmentConfig;
13use tokio::sync::oneshot;
14
15use crate::{constants, types::DataAndSlot, UnsubHandle};
16
17#[derive(Clone, Debug)]
18pub struct ProgramAccountUpdate<T: AnchorDeserialize + Send> {
19    pub pubkey: String,
20    pub data_and_slot: DataAndSlot<T>,
21    pub now: Instant,
22}
23
24impl<T: AnchorDeserialize + Send> ProgramAccountUpdate<T> {
25    pub fn new(pubkey: String, data_and_slot: DataAndSlot<T>, now: Instant) -> Self {
26        Self {
27            pubkey,
28            data_and_slot,
29            now,
30        }
31    }
32}
33
34#[derive(Clone)]
35pub struct WebsocketProgramAccountOptions {
36    pub filters: Vec<RpcFilterType>,
37    pub commitment: CommitmentConfig,
38    pub encoding: UiAccountEncoding,
39}
40
41pub struct WebsocketProgramAccountSubscriber {
42    url: String,
43    pub options: WebsocketProgramAccountOptions,
44}
45
46impl WebsocketProgramAccountSubscriber {
47    pub fn new(url: String, options: WebsocketProgramAccountOptions) -> Self {
48        WebsocketProgramAccountSubscriber { url, options }
49    }
50
51    /// Start a GPA subscription task
52    ///
53    /// `subscription_name` some user defined identifier for the subscription
54    /// `on_update` handles updates from the subscription task
55    pub fn subscribe<T, F>(&self, subscription_name: &'static str, on_update: F) -> UnsubHandle
56    where
57        T: AnchorDeserialize + Clone + Send + 'static,
58        F: 'static + Send + Fn(&ProgramAccountUpdate<T>),
59    {
60        let account_config = RpcAccountInfoConfig {
61            commitment: Some(self.options.commitment),
62            encoding: Some(self.options.encoding),
63            ..Default::default()
64        };
65        let config = RpcProgramAccountsConfig {
66            filters: Some(self.options.filters.clone()),
67            account_config,
68            ..Default::default()
69        };
70
71        let (unsub_tx, mut unsub_rx) = oneshot::channel::<()>();
72        let url = self.url.clone();
73
74        tokio::spawn(async move {
75            let mut latest_slot = 0;
76            loop {
77                let pubsub = match PubsubClient::new(&url).await {
78                    Ok(pubsub) => pubsub,
79                    Err(err) => {
80                        log::error!("GPA stream connect failed: {err:?}");
81                        continue;
82                    }
83                };
84                let (mut accounts, unsub) = match pubsub
85                    .program_subscribe(&constants::PROGRAM_ID, Some(config.clone()))
86                    .await
87                {
88                    Ok(res) => res,
89                    Err(err) => {
90                        log::error!("GPA stream subscribe failed: {err:?}");
91                        continue;
92                    }
93                };
94
95                let res = loop {
96                    tokio::select! {
97                        biased;
98                        message = accounts.next() => {
99                            match message {
100                                Some(message) => {
101                                    let slot = message.context.slot;
102                                    if slot >= latest_slot {
103                                        latest_slot = slot;
104                                        let pubkey = message.value.pubkey;
105                                        let data = &message.value.account.data.decode().expect("account has data");
106                                        let data = T::deserialize(&mut &data[8..]).expect("deserializes T");
107                                        on_update(&ProgramAccountUpdate::new(pubkey, DataAndSlot::<T> { slot, data }, Instant::now()));
108                                    }
109                                },
110                                None => {
111                                    log::error!("{subscription_name}: Ws GPA stream ended unexpectedly");
112                                    break Err(());
113                                }
114                            }
115                        }
116                        _ = &mut unsub_rx => {
117                            warn!("unsubscribing: {subscription_name}");
118                            unsub().await;
119                            break Ok(());
120                        }
121                    }
122                };
123                if res.is_ok() {
124                    break;
125                }
126            }
127        });
128
129        unsub_tx
130    }
131}
132
#[cfg(feature = "rpc_tests")]
mod tests {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    use super::*;
    use crate::{
        drift_idl::accounts::User,
        memcmp::{get_non_idle_user_filter, get_user_filter},
        utils::test_envs::mainnet_endpoint,
    };

    /// Smoke test against a live endpoint: subscribe to non-idle `User` accounts for a few
    /// seconds, count the updates received, then unsubscribe via the returned handle.
    #[tokio::test]
    async fn test_subscribe() {
        let filters = vec![get_user_filter(), get_non_idle_user_filter()];
        let options = WebsocketProgramAccountOptions {
            filters,
            commitment: CommitmentConfig::confirmed(),
            encoding: UiAccountEncoding::Base64,
        };
        // NOTE(review): assumes `mainnet_endpoint()` yields a URL `PubsubClient::new` accepts
        // (ws/wss); convert it first if it returns an http(s) endpoint.
        let subscriber = WebsocketProgramAccountSubscriber::new(mainnet_endpoint(), options);

        let update_count = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&update_count);
        let unsub = subscriber.subscribe::<User, _>("Test", move |update| {
            counter.fetch_add(1, Ordering::Relaxed);
            dbg!(&update.pubkey);
        });
        dbg!("sub'd");

        tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
        let _ = unsub.send(());
        dbg!("unsub'd", update_count.load(Ordering::Relaxed));
    }
}