use crate::{
common::{
endpoint::BnEndpoint,
payload::BnWsApiPayload,
ratelimiter::BnRatelimiter,
response::BnWsApiRespType,
signer::BnSigner,
ws::{
api::{BnWsApi, BnWsApiCall},
stream::{BnWsStream, BnWsStreamCall},
},
},
data::enums::ratelimit::BnRateLimitType,
};
use serde::{Serialize, de::DeserializeOwned};
use std::{num::NonZeroU32, sync::Arc};
use tokio::sync::{Mutex, OnceCell, mpsc, oneshot};
use typed_builder::TypedBuilder;
use ulid::Ulid;
use xapi_shared::{
ratelimiter::SharedRatelimiterTrait, rest::SharedRestClientTrait, signer::SharedSignerTrait,
ws::error::SharedWsError,
};
/// Executes REST and WebSocket requests against a configured endpoint.
///
/// Built via the `TypedBuilder`-generated builder; only `endpoint` is required, every other
/// field has a builder default.
#[derive(TypedBuilder)]
pub struct BnExecutor {
/// Endpoint configuration; supplies the WS API and WS stream base URLs used by this executor.
endpoint: BnEndpoint,
/// HTTP client handed out through `SharedRestClientTrait::get_client`.
/// Defaults to `reqwest::Client::new()`.
#[builder(default = reqwest::Client::new())]
rest_client: reqwest::Client,
/// Optional request signer. Required for signed WS API calls; `get_signer` panics without it.
/// The builder setter takes a bare `BnSigner` (`strip_option`).
#[builder(default = None, setter(strip_option))]
signer: Option<BnSigner>,
/// Shared rate limiter consulted before WS API calls and exposed via `get_ratelimiter`.
/// Defaults to `BnRatelimiter::default()`.
#[builder(default = Arc::new(BnRatelimiter::default()))]
ratelimiter: Arc<BnRatelimiter>,
/// WS API connection, created lazily on the first WS API call.
#[builder(default)]
ws_api: OnceCell<ezsockets::Client<BnWsApi>>,
/// Handles of WS stream connections opened by `subscribe_stream`.
/// NOTE(review): presumably retained so the connections are not dropped — confirm.
#[builder(default)]
streams: Mutex<Vec<ezsockets::Client<BnWsStream>>>,
}
impl SharedRestClientTrait<BnRateLimitType> for BnExecutor {
fn get_client(&self) -> &reqwest::Client {
&self.rest_client
}
fn get_signer(&self) -> &dyn SharedSignerTrait {
if let Some(signer) = &self.signer {
signer
} else {
tracing::error!("signer is not set for BnExecutor");
panic!("signer is not set for BnExecutor");
}
}
fn get_ratelimiter(&self) -> Arc<dyn SharedRatelimiterTrait<BnRateLimitType> + Sync + Send> {
self.ratelimiter.clone()
}
}
impl BnExecutor {
    /// Returns the endpoint configuration this executor was built with.
    pub fn get_endpoint(&self) -> &BnEndpoint {
        &self.endpoint
    }

    /// Sends a request over the WS API and deserializes the response's `result` into `ResType`.
    ///
    /// * `limits` — `(limit type, weight)` pairs awaited on the shared rate limiter before sending.
    /// * `signed` — whether the serialized `params` must be signed with the configured signer.
    /// * `method` — method name placed in the request payload.
    /// * `params` — optional request parameters, serialized to a JSON value.
    ///
    /// Each request gets a freshly generated ULID as its `id`. The WS API connection is opened
    /// lazily on the first call.
    ///
    /// # Errors
    ///
    /// * `SharedWsError::SerdeError` if `params` fail to serialize or `result` cannot be
    ///   deserialized into `ResType`.
    /// * `SharedWsError::AppError` if `signed` is set but no signer is configured. This is
    ///   checked before any rate-limit budget is consumed.
    /// * `SharedWsError::ChannelClosedError` if the request cannot be handed to the WS client or
    ///   its reply channel closes.
    /// * Errors from the rate limiter, the signer, and the WS API response are propagated.
    pub async fn call_ws_api<ReqType: Serialize, ResType: DeserializeOwned>(
        &self,
        limits: &[(BnRateLimitType, NonZeroU32)],
        signed: bool,
        method: &str,
        params: Option<ReqType>,
    ) -> BnWsApiRespType<ResType> {
        let params = params
            .map(|p| {
                serde_json::to_value(p)
                    .inspect_err(|err| tracing::error!(?err, "failed to serialize ws api params"))
                    .map_err(|err| SharedWsError::SerdeError(err.to_string()))
            })
            .transpose()?;

        // Resolve the signer up front so a misconfigured executor fails fast, before any
        // rate-limit budget is spent.
        let signer = match (signed, self.signer.as_ref()) {
            (false, _) => None,
            (true, Some(signer)) => Some(signer),
            (true, None) => {
                tracing::error!("signer is not set for BnExecutor");
                return Err(SharedWsError::AppError("signer is not set".to_string()));
            }
        };

        // One Arc clone for the whole loop instead of one per limit.
        let ratelimiter = self.get_ratelimiter();
        for (rate_limit_type, limit) in limits {
            ratelimiter.limit_on(rate_limit_type, *limit).await?;
        }

        // Sign only after any rate-limit wait, so the signed payload is produced as close to
        // sending as possible.
        // NOTE(review): this matters if the signer embeds a timestamp checked against a
        // receive window — confirm against `sign_ws_payload`.
        let params = match signer {
            Some(signer) => Some(signer.sign_ws_payload(params)?),
            None => params,
        };

        let api = self.get_ws_api().await;
        let (tx, rx) = oneshot::channel();
        api.call(BnWsApiCall::SendApi {
            payload: BnWsApiPayload {
                id: Ulid::new().to_string(),
                method: method.to_string(),
                params,
            },
            tx,
        })
        .inspect_err(|err| tracing::error!(?err, "failed to send ws api request"))
        .map_err(|err| SharedWsError::ChannelClosedError(err.to_string()))?;

        // Outer `?`: the reply channel closed. Inner `?`: the WS API reported an error.
        let resp = rx
            .await
            .inspect_err(|err| {
                tracing::error!(?err, "failed to receive ws api response");
            })
            .map_err(|err| SharedWsError::ChannelClosedError(err.to_string()))??;

        serde_json::from_value(resp.result)
            .inspect_err(|err| tracing::error!(?err, "failed to parse ws api response result"))
            .map_err(|err| SharedWsError::SerdeError(err.to_string()))
    }

    /// Opens a dedicated WS stream connection, subscribes it to `stream`, and returns a channel
    /// of messages whose `data` payload is deserialized into `T`.
    ///
    /// The connection handle is stored in `self.streams` only after the subscription has been
    /// acknowledged. Errors reported by the stream and deserialization failures are delivered
    /// through the returned channel as `Err` items; they do not end it.
    ///
    /// A background task does the forwarding. It exits when the underlying stream channel
    /// closes, or on the first failed send after the returned receiver has been dropped.
    ///
    /// # Errors
    ///
    /// * `SharedWsError::AppError` if the subscribe request cannot be handed to the WS client.
    /// * `SharedWsError::ChannelClosedError` if the acknowledgement channel closes.
    /// * Any error carried in the acknowledgement itself is propagated.
    pub async fn subscribe_stream<T: DeserializeOwned + Send + 'static>(
        &self,
        stream: String,
    ) -> Result<mpsc::Receiver<Result<T, SharedWsError>>, SharedWsError> {
        let ws_stream_base_url = self.endpoint.get_ws_stream_base_url().clone();
        let client = BnWsStream::connect(ezsockets::ClientConfig::new(ws_stream_base_url)).await;
        let (raw_tx, mut raw_rx) = mpsc::channel(128);
        let (oneshot_tx, oneshot_rx) = oneshot::channel();
        let message = BnWsStreamCall::SubscribeStream {
            streams: vec![(stream, raw_tx)],
            tx: oneshot_tx,
        };
        client
            .call(message)
            .inspect_err(|err| tracing::error!(?err, "failed to subscribe to stream"))
            .map_err(|err| SharedWsError::AppError(err.to_string()))?;
        oneshot_rx
            .await
            .map_err(|err| SharedWsError::ChannelClosedError(err.to_string()))??;

        // Register the connection only once the subscription is confirmed. Otherwise a failed
        // subscription would leave a useless connection in `streams` forever.
        self.streams.lock().await.push(client);

        let (tx, rx) = mpsc::channel(128);
        tokio::spawn(async move {
            while let Some(result) = raw_rx.recv().await {
                let msg = result.and_then(|resp| {
                    serde_json::from_value::<T>(resp.data).map_err(|err| {
                        tracing::error!(?err, "failed to parse message");
                        SharedWsError::SerdeError(err.to_string())
                    })
                });
                if let Err(err) = tx.send(msg).await {
                    // The consumer dropped its receiver. Every later send would fail the same
                    // way, so stop forwarding instead of logging an error for each message.
                    tracing::error!(?err, "failed to send message");
                    break;
                }
            }
        });
        Ok(rx)
    }

    /// Returns the shared WS API client, connecting it on first use.
    ///
    /// `OnceCell::get_or_init` stores the first successfully produced client. Concurrent first
    /// callers wait for the in-flight connect instead of opening their own connection.
    async fn get_ws_api(&self) -> &ezsockets::Client<BnWsApi> {
        self.ws_api
            .get_or_init(async || {
                let ws_api_base_url = self.endpoint.get_ws_api_base_url().clone();
                BnWsApi::connect(ezsockets::ClientConfig::new(ws_api_base_url)).await
            })
            .await
    }
}