drift_rs/
blockhash_subscriber.rs1use std::{
2 collections::VecDeque,
3 sync::{Arc, Mutex, RwLock},
4 time::Duration,
5};
6
7use log::warn;
8use solana_rpc_client::nonblocking::rpc_client::RpcClient;
9use solana_sdk::hash::Hash;
10use tokio::sync::oneshot;
11
12use crate::UnsubHandle;
13
14pub struct BlockhashSubscriber {
16 refresh_frequency: Duration,
17 last_twenty_hashes: Arc<RwLock<VecDeque<Hash>>>,
18 rpc_client: Arc<RpcClient>,
19 unsub: Mutex<Option<UnsubHandle>>,
20}
21
22impl BlockhashSubscriber {
23 pub fn new(refresh_frequency: Duration, rpc_client: Arc<RpcClient>) -> Self {
26 BlockhashSubscriber {
27 last_twenty_hashes: Arc::new(RwLock::new(VecDeque::with_capacity(20))),
28 rpc_client: Arc::clone(&rpc_client),
29 refresh_frequency,
30 unsub: Mutex::default(),
31 }
32 }
33
34 pub fn subscribe(&self) {
36 let (unsub_tx, mut unsub_rx) = oneshot::channel();
37 {
38 let mut guard = self.unsub.try_lock().expect("uncontested");
39 if guard.is_some() {
40 return;
41 }
42 guard.replace(unsub_tx);
43 }
44
45 tokio::spawn({
46 let rpc_client = Arc::clone(&self.rpc_client);
47 let last_twenty_hashes = Arc::clone(&self.last_twenty_hashes);
48 let mut refresh = tokio::time::interval(self.refresh_frequency);
49
50 let max_attempts = 3;
51 let mut attempts = 0;
52 async move {
53 loop {
54 let _ = refresh.tick().await;
55 match rpc_client.get_latest_blockhash().await {
56 Ok(blockhash) => {
57 attempts = 0;
58 let mut hashes = last_twenty_hashes.write().expect("acquired");
59 hashes.push_back(blockhash);
60 if hashes.len() > 20 {
61 let _ = hashes.pop_front();
62 }
63 }
64 Err(err) => {
65 warn!("blockhash subscriber missed update: {err:?}");
66 attempts += 1;
67 if attempts > max_attempts {
68 panic!("unable to fetch blockhash");
69 }
70 }
71 }
72
73 if unsub_rx.try_recv().is_ok() {
74 warn!("unsubscribing from blockhashes");
75 break;
76 }
77 }
78
79 let mut lock = last_twenty_hashes.write().expect("acquired");
80 lock.clear();
81 }
82 });
83 }
84
85 pub fn get_latest_blockhash(&self) -> Option<Hash> {
87 let lock = self.last_twenty_hashes.read().expect("acquired");
88 lock.back().copied()
89 }
90
91 pub fn get_valid_blockhash(&self) -> Option<Hash> {
93 let lock = self.last_twenty_hashes.read().expect("acquired");
94 lock.front().copied()
95 }
96
97 pub fn unsubscribe(&self) {
99 let mut guard = self.unsub.lock().expect("uncontested");
100 if let Some(unsub) = guard.take() {
101 if unsub.send(()).is_err() {
102 log::error!("couldn't unsubscribe");
103 }
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use serde_json::json;
    use solana_rpc_client::rpc_client::Mocks;
    use solana_rpc_client_api::request::RpcRequest;

    use super::*;

    /// Subscribing replaces the newest hash with the RPC response and evicts the
    /// oldest entry; unsubscribing empties the cache.
    #[tokio::test]
    async fn blockhash_subscriber_updates() {
        let _ = env_logger::try_init();

        let mut mocks = Mocks::default();
        let mocked_latest = Hash::new_unique();
        let seeded_oldest = Hash::new_unique();
        let filler = Hash::new_unique();

        // The mock RPC answers every `getLatestBlockhash` with `mocked_latest`.
        let mocked_response = json!({
            "context": {
                "slot": 12345,
            },
            "value": {
                "blockhash": mocked_latest.to_string(),
                "lastValidBlockHeight": 1,
            }
        });
        mocks.insert(RpcRequest::GetLatestBlockhash, mocked_response);
        let rpc = RpcClient::new_mock_with_mocks("https://api.mainnet-beta.solana.com".into(), mocks);

        // Seed the cache with one identifiable oldest hash followed by 20 fillers.
        let mut seeded = VecDeque::with_capacity(21);
        seeded.push_back(seeded_oldest);
        seeded.extend(std::iter::repeat(filler).take(20));

        let subscriber = BlockhashSubscriber {
            last_twenty_hashes: Arc::new(RwLock::new(seeded)),
            unsub: Mutex::default(),
            rpc_client: Arc::new(rpc),
            refresh_frequency: Duration::from_secs(4),
        };

        // Before subscribing only the seeded values are visible.
        assert_eq!(subscriber.get_valid_blockhash().unwrap(), seeded_oldest);
        assert_ne!(subscriber.get_latest_blockhash().unwrap(), mocked_latest);

        // After the first poll the mocked hash is the newest entry and the seeded
        // oldest entry has been evicted.
        subscriber.subscribe();
        tokio::time::sleep(Duration::from_secs(2)).await;
        assert_eq!(subscriber.get_latest_blockhash().unwrap(), mocked_latest);
        assert_ne!(subscriber.get_valid_blockhash().unwrap(), seeded_oldest);

        // Once the task has observed the unsubscribe signal the cache is empty.
        subscriber.unsubscribe();
        tokio::time::sleep(Duration::from_secs(4)).await;
        assert!(subscriber.get_latest_blockhash().is_none());
    }
}