use async_trait::async_trait;
use anchor_lang::Event;
use solana_sdk::commitment_config::CommitmentConfig;
use solana_client::nonblocking::pubsub_client;
use solana_client::rpc_config::{RpcTransactionLogsConfig, RpcTransactionLogsFilter};
use solana_program::pubkey::Pubkey;
use crate::OnDemandError;
use base64::engine::general_purpose;
use base64::Engine;
use futures::StreamExt;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::Arc;
/// Destination for decoded events of type `E`, such as a channel sender.
#[async_trait]
pub trait EventSenderTrait<E>: Send + Sync
where
    E: Event + Send + Sync + 'static,
{
    /// Delivers `event`, returning an [`OnDemandError`] when it cannot be delivered.
    async fn send(&self, event: E) -> Result<(), OnDemandError>;
}
/// Forwards events into a bounded Tokio channel.
#[async_trait]
impl<E> EventSenderTrait<E> for mpsc::Sender<E>
where
    E: Event + Send + Sync + 'static,
{
    async fn send(&self, event: E) -> Result<(), OnDemandError> {
        // Name the inherent `Sender::send` explicitly so this cannot be misread as a
        // recursive call to the trait method of the same name.
        match mpsc::Sender::send(self, event).await {
            Ok(()) => Ok(()),
            // A send error (the channel was closed) is reported as a network error.
            Err(_closed) => Err(OnDemandError::NetworkError),
        }
    }
}
/// Forwards events into an unbounded Tokio channel.
#[async_trait]
impl<E> EventSenderTrait<E> for mpsc::UnboundedSender<E>
where
    E: Event + Send + Sync + 'static,
{
    async fn send(&self, event: E) -> Result<(), OnDemandError> {
        // Explicit inherent call: unbounded sends complete immediately, so nothing is awaited.
        mpsc::UnboundedSender::send(self, event).map_err(|_closed| OnDemandError::NetworkError)
    }
}
/// Type-erased consumer of one event type's raw payload.
///
/// The payload is the event bytes that follow the 8-byte discriminator.
#[async_trait]
pub trait EventHandler: Send + Sync {
    /// Processes one payload. `EventHandlerImpl` decodes it and forwards the event.
    async fn handle_event(&self, data: &[u8]) -> Result<(), OnDemandError>;
}
/// [`EventHandler`] that decodes a single event type `E` and forwards it to the sink `S`.
///
/// The `Event` / `EventSenderTrait` bounds live on the impl blocks that need them rather than
/// on the struct definition, so merely naming the type does not have to repeat them.
pub struct EventHandlerImpl<E, S> {
    /// Sink that receives every successfully decoded event.
    sender: Arc<S>,
    /// Ties the handler to its event type without storing an `E`.
    _marker: PhantomData<E>,
}
/// Deserializes the payload as `E` and hands the event to the configured sender.
#[async_trait]
impl<E, S> EventHandler for EventHandlerImpl<E, S>
where
    E: Event + Send + Sync + 'static,
    S: EventSenderTrait<E>,
{
    async fn handle_event(&self, data: &[u8]) -> Result<(), OnDemandError> {
        // Any deserialization failure is surfaced as a parse error; the cause is dropped.
        let event = E::try_from_slice(data).map_err(|_| OnDemandError::AnchorParseError)?;
        self.sender.send(event).await
    }
}
impl<E, S> EventHandlerImpl<E, S>
where
    E: Event + Send + Sync + 'static,
    S: EventSenderTrait<E>,
{
    /// Wraps `sender` in a handler for events of type `E`.
    pub fn new(sender: S) -> Self {
        let shared_sender = Arc::new(sender);
        Self {
            sender: shared_sender,
            _marker: PhantomData,
        }
    }
}
/// Builder for a log-subscription event client.
///
/// Register a handler with `add_event_handler` to obtain a startable
/// [`PubSubEventClientWithHandlers`].
#[derive(Debug, Clone)]
pub struct PubSubEventClientBuilder {
    /// Program whose logs are subscribed to.
    program_id: Pubkey,
    /// WebSocket RPC endpoint.
    websocket_url: String,
    /// Extra addresses passed in the `mentions` log filter alongside `program_id`.
    other_pubkeys: Vec<Pubkey>,
    /// Reconnection limit; `None` means unlimited.
    max_retries: Option<i32>,
}
/// Event client with registered handlers, created by
/// `PubSubEventClientBuilder::add_event_handler`; run it with `start`.
pub struct PubSubEventClientWithHandlers {
    /// Program whose logs are subscribed to.
    program_id: Pubkey,
    /// WebSocket RPC endpoint.
    websocket_url: String,
    /// Extra addresses passed in the `mentions` log filter alongside `program_id`.
    other_pubkeys: Vec<Pubkey>,
    /// Reconnection limit; `None` means unlimited.
    max_retries: Option<i32>,
    /// Cancelling this token makes a running `start` return.
    cancellation_token: CancellationToken,
    /// Handlers keyed by the 8-byte discriminator that prefixes each encoded event.
    event_handlers: HashMap<[u8; 8], Box<dyn EventHandler>>,
}
impl PubSubEventClientBuilder {
    /// Creates a builder for events emitted by `program_id`, read from `websocket_url`.
    ///
    /// A leading `https://` / `http://` scheme is rewritten to `wss://` / `ws://`, so an HTTP
    /// RPC URL can be passed directly. Any other URL is stored unchanged.
    pub fn new(program_id: Pubkey, websocket_url: String) -> Self {
        Self {
            program_id,
            websocket_url: Self::websocket_url_from(websocket_url),
            other_pubkeys: Vec::new(),
            max_retries: None,
        }
    }

    /// Rewrites a leading HTTP(S) scheme to the matching WebSocket scheme.
    ///
    /// Only the prefix is inspected. Scheme-like text elsewhere in the URL (for example
    /// inside a query parameter) is left untouched.
    fn websocket_url_from(url: String) -> String {
        if let Some(rest) = url.strip_prefix("https://") {
            format!("wss://{rest}")
        } else if let Some(rest) = url.strip_prefix("http://") {
            format!("ws://{rest}")
        } else {
            url
        }
    }

    /// Adds `pubkey` to the addresses passed in the `mentions` log filter, alongside the
    /// program id.
    ///
    /// NOTE(review): Solana's `logsSubscribe` RPC documents `mentions` as accepting a single
    /// address, so standard RPC nodes may reject the subscription once extra pubkeys are
    /// added. Confirm against the target RPC.
    pub fn mentions(mut self, pubkey: Pubkey) -> Self {
        self.other_pubkeys.push(pubkey);
        self
    }

    /// Limits how many times the client retries after a connection failure.
    ///
    /// Without a limit it keeps retrying forever.
    pub fn set_max_retries(mut self, max_retries: i32) -> Self {
        self.max_retries = Some(max_retries);
        self
    }

    /// Registers `sender` as the destination for events of type `E` and returns a client
    /// ready to be started.
    pub fn add_event_handler<E: Event + Send + Sync + 'static, S: EventSenderTrait<E> + 'static>(
        self,
        sender: S,
    ) -> PubSubEventClientWithHandlers {
        // Start from a client with no handlers and reuse its registration logic, so the two
        // `add_event_handler` methods cannot drift apart.
        PubSubEventClientWithHandlers {
            program_id: self.program_id,
            websocket_url: self.websocket_url,
            other_pubkeys: self.other_pubkeys,
            max_retries: self.max_retries,
            cancellation_token: CancellationToken::new(),
            event_handlers: HashMap::new(),
        }
        .add_event_handler::<E, S>(sender)
    }
}
impl PubSubEventClientWithHandlers {
pub fn mentions(mut self, pubkey: Pubkey) -> Self {
self.other_pubkeys.push(pubkey);
self
}
pub fn set_max_retries(mut self, max_retries: i32) -> Self {
self.max_retries = Some(max_retries);
self
}
pub fn add_event_handler<E: Event + Send + Sync + 'static, S: EventSenderTrait<E> + 'static>(
mut self,
sender: S,
) -> PubSubEventClientWithHandlers {
self.event_handlers.insert(
E::DISCRIMINATOR,
Box::new(EventHandlerImpl::<E, S>::new(sender)) as Box<dyn EventHandler + 'static>,
);
self
}
pub fn abort(self) {
self.cancellation_token.cancel();
}
pub async fn start(self) {
let cancellation_token = self.cancellation_token.clone();
tokio::select! {
_ = cancellation_token.cancelled() => {
log::info!("pubsub token cancelled");
},
_ = self.start_pubsub() => {
log::info!("start_pubsub returned unexpectedly");
}
}
}
async fn start_pubsub(&self) {
let mut retry_count = 0;
let mut delay = Duration::from_millis(500);
loop {
let pubsub_client = pubsub_client::PubsubClient::new(&self.websocket_url)
.await
.expect("Failed to create pubsub client");
let connection_result = pubsub_client
.logs_subscribe(
RpcTransactionLogsFilter::Mentions(
vec![
vec![self.program_id.to_string()],
self.other_pubkeys
.iter()
.map(|pubkey| pubkey.to_string())
.collect(),
]
.concat(),
),
RpcTransactionLogsConfig {
commitment: Some(CommitmentConfig::processed()),
},
)
.await;
match connection_result {
Ok((mut stream, _handler)) => {
retry_count = 0; delay = Duration::from_millis(500);
while let Some(event) = stream.next().await {
for line in event.value.logs {
if let Some(encoded_data) = line.strip_prefix("Program data: ") {
if let Ok(decoded_data) =
general_purpose::STANDARD.decode(encoded_data)
{
if decoded_data.len() <= 8 {
continue;
}
let (disc, event_data) = decoded_data.split_at(8);
if let Some(sender) = self.event_handlers.get(disc) {
let _ = sender.handle_event(event_data).await;
}
}
}
}
}
log::error!("[EVENT][WEBSOCKET] connection closed, attempting to reconnect...");
}
Err(e) => {
log::error!("[EVENT][WEBSOCKET] Failed to connect: {:?}", e);
match self.max_retries {
Some(max_retries) => {
if retry_count >= max_retries {
log::error!("[EVENT][WEBSOCKET] Maximum retry attempts reached, aborting...");
break;
}
tokio::time::sleep(delay).await; retry_count += 1;
delay = std::cmp::min(delay * 2, Duration::from_secs(5));
}
None => {
tokio::time::sleep(delay).await; retry_count += 1;
delay = std::cmp::min(delay * 2, Duration::from_secs(5));
continue;
}
}
}
}
if let Some(max_retries) = self.max_retries {
if retry_count >= max_retries {
log::error!("[EVENT][WEBSOCKET] Maximum retry attempts reached, aborting...");
break;
}
}
}
}
}