use std::sync::Arc;
use solana_rpc_client::nonblocking::rpc_client::RpcClient;
use solana_sdk::clock::DEFAULT_MS_PER_SLOT;
use tokio::{
sync::watch,
task::JoinHandle,
time::{sleep, Duration},
};
use tracing::warn;
use crate::batch_client::messages::BlockMessage;
/// Spawns a background task that polls the RPC node roughly once per slot and publishes
/// the latest blockhash / block height on `blockdata_tx`.
///
/// Each poll issues `get_latest_blockhash_with_commitment` and
/// `get_epoch_info_with_commitment` at the client's configured commitment and combines
/// them into a [`BlockMessage`]. A message is sent only when it differs from the last value
/// published on the channel (initially, whatever value the channel was created with), so
/// receivers are only woken for real changes. A failed RPC call is logged and retried on
/// the next tick.
///
/// The task exits once every receiver of `blockdata_tx` has been dropped. This is checked
/// after every sleep, so the task also terminates while the RPC node keeps failing.
pub fn spawn_block_watcher(
    blockdata_tx: watch::Sender<BlockMessage>,
    rpc_client: Arc<RpcClient>,
) -> JoinHandle<()> {
    tokio::spawn(async move {
        let poll_period = Duration::from_millis(DEFAULT_MS_PER_SLOT);
        let commitment = rpc_client.commitment();
        // Baseline is what receivers currently observe, not `BlockMessage::default()`:
        // otherwise an update equal to the seed would be re-sent, and an update equal to
        // `default()` would never replace a different seed.
        let mut last_update = *blockdata_tx.borrow();
        loop {
            sleep(poll_period).await;
            // Must happen before the RPC calls: when they fail we `continue` and would
            // otherwise never notice that every receiver is gone.
            if blockdata_tx.is_closed() {
                break;
            }
            let (blockhash, last_valid_block_height) = match rpc_client
                .get_latest_blockhash_with_commitment(commitment)
                .await
            {
                Ok(latest) => latest,
                Err(err) => {
                    warn!(%err, "failed to get latest blockhash");
                    continue;
                }
            };
            let epoch_info = match rpc_client
                .get_epoch_info_with_commitment(commitment)
                .await
            {
                Ok(info) => info,
                Err(err) => {
                    warn!(%err, "failed to get epoch info");
                    continue;
                }
            };
            let new_update = BlockMessage {
                blockhash,
                last_valid_block_height,
                block_height: epoch_info.block_height,
            };
            if new_update == last_update {
                continue;
            }
            last_update = new_update;
            if blockdata_tx.send(new_update).is_err() {
                warn!("no receivers for block updates, shutting down block watcher");
                break;
            }
        }
        warn!("shutting down block watcher");
    })
}
#[cfg(test)]
mod tests {
    use std::cmp::min;

    use async_trait::async_trait;
    use solana_client::{
        rpc_client::RpcClientConfig,
        rpc_request::RpcRequest,
        rpc_response::{Response, RpcBlockhash, RpcResponseContext},
        rpc_sender::{RpcSender, RpcTransportStats},
    };
    use solana_rpc_client::mock_sender::MockSender;
    use solana_rpc_client_api::client_error::Result as SolanaResult;
    use solana_sdk::{epoch_info::EpochInfo, hash::Hash};
    use tokio::time::Instant;
    use tracing::Level;

    use super::*;

    /// Offset the mock adds to the current slot to produce `last_valid_block_height`.
    const VALID_BLOCKHASH_WINDOW: u64 = 150;

    /// The `BlockMessage` the mock RPC reports once the simulated chain is at `slot`.
    fn expected_for_slot(slot: u64) -> BlockMessage {
        BlockMessage {
            blockhash: Hash::default(),
            last_valid_block_height: slot + VALID_BLOCKHASH_WINDOW,
            block_height: slot,
        }
    }

    #[tokio::test(start_paused = true)]
    async fn test_block_watcher() {
        let _ = tracing_subscriber::fmt()
            .with_max_level(Level::TRACE)
            .try_init();

        let start = Instant::now();
        // Absolute (virtual) deadline `ms` milliseconds after the test started.
        let at = |ms: u64| start + Duration::from_millis(ms);

        let seeded = BlockMessage {
            blockhash: Hash::default(),
            last_valid_block_height: 1234,
            block_height: 5678,
        };
        let (tx, mut rx) = watch::channel(seeded);
        let mock = MockBlockSender {
            sender: MockSender::new("succeeds"),
            initial_time: start,
            max_slot: 3,
        };
        let client = Arc::new(RpcClient::new_sender(mock, RpcClientConfig::default()));
        let handle = spawn_block_watcher(tx, client);

        // Nothing is published before the first slot has elapsed.
        assert_eq!(*rx.borrow_and_update(), seeded);
        tokio::time::sleep_until(at(DEFAULT_MS_PER_SLOT / 2)).await;
        assert_eq!(*rx.borrow_and_update(), seeded);

        // Just past slot 1 the first poll has replaced the seed.
        tokio::time::sleep_until(at(DEFAULT_MS_PER_SLOT + 1)).await;
        assert_eq!(*rx.borrow_and_update(), expected_for_slot(1));

        // Just past slot 3 the mock has reached its ceiling.
        tokio::time::sleep_until(at(3 * DEFAULT_MS_PER_SLOT + 1)).await;
        assert_eq!(*rx.borrow_and_update(), expected_for_slot(3));

        // The simulated chain is frozen at `max_slot`, so no further change is announced.
        let no_change =
            tokio::time::timeout_at(at(6 * DEFAULT_MS_PER_SLOT + 1), rx.changed()).await;
        assert!(no_change.is_err());

        // Dropping the only receiver must let the watcher task finish.
        drop(rx);
        handle.await.unwrap();
    }

    /// RPC transport whose answers are derived from the virtual time elapsed since
    /// `initial_time`; every other request is delegated to `sender`.
    struct MockBlockSender {
        /// Fallback for requests this mock does not answer itself.
        sender: MockSender,
        /// Virtual instant at which slot 0 began.
        initial_time: Instant,
        /// Slot at which the simulated chain stops advancing.
        max_slot: u64,
    }

    impl MockBlockSender {
        /// Slot implied by the elapsed virtual time, capped at `max_slot`.
        fn current_slot(&self) -> u64 {
            let elapsed_ms = Instant::now().duration_since(self.initial_time).as_millis();
            let uncapped = (elapsed_ms / DEFAULT_MS_PER_SLOT as u128) as u64;
            min(uncapped, self.max_slot)
        }
    }

    #[async_trait]
    impl RpcSender for MockBlockSender {
        async fn send(
            &self,
            request: RpcRequest,
            params: serde_json::Value,
        ) -> SolanaResult<serde_json::Value> {
            let slot = self.current_slot();
            match request {
                RpcRequest::GetLatestBlockhash => {
                    let response = Response {
                        context: RpcResponseContext {
                            slot,
                            api_version: None,
                        },
                        value: RpcBlockhash {
                            blockhash: Hash::default().to_string(),
                            last_valid_block_height: slot + VALID_BLOCKHASH_WINDOW,
                        },
                    };
                    Ok(serde_json::to_value(response)?)
                }
                RpcRequest::GetEpochInfo => {
                    let info = EpochInfo {
                        epoch: 0,
                        slot_index: slot,
                        slots_in_epoch: 256,
                        absolute_slot: slot,
                        block_height: slot,
                        transaction_count: Some(123),
                    };
                    Ok(serde_json::to_value(info)?)
                }
                _ => self.sender.send(request, params).await,
            }
        }

        fn get_transport_stats(&self) -> RpcTransportStats {
            self.sender.get_transport_stats()
        }

        fn url(&self) -> String {
            self.sender.url()
        }
    }
}