use crate::common::{
    payload::{BnWsStreamMethod, BnWsStreamPayload},
    response::{BnWsStreamData, BnWsStreamResponse},
};
use ezsockets::{Bytes, Client, ClientConfig, ClientExt, Error, Utf8Bytes};
use std::collections::{HashMap, HashSet};
use tokio::sync::{mpsc, oneshot};
use ulid::Ulid;
use xapi_shared::ws::{api::SharedWsApiTrait, error::SharedWsError, stream::SharedWsStreamTrait};
/// State of one Binance websocket stream connection, driven by an `ezsockets` client.
pub struct BnWsStream {
    /// Handle to the `ezsockets` client that owns this connection.
    client: Client<Self>,
    /// Signalled the first time the socket connects; taken (left `None`) after that.
    on_connect_tx: Option<oneshot::Sender<()>>,
    /// Pending one-shot request replies (presumably keyed by request id — confirm in
    /// `SharedWsApiTrait::send_oneshot`).
    oneshot_tx_map: HashMap<String, oneshot::Sender<Result<BnWsStreamResponse, SharedWsError>>>,
    /// Data channels for subscribed streams, keyed by the stream name sent to the server.
    stream_tx_map: HashMap<String, mpsc::Sender<Result<BnWsStreamData, SharedWsError>>>,
}
/// Calls that can be dispatched to a running [`BnWsStream`] through its `Client` handle.
pub enum BnWsStreamCall {
    /// Subscribe to one or more Binance streams.
    SubscribeStream {
        /// Stream names, each paired with the channel that will receive that stream's data.
        streams: Vec<(String, mpsc::Sender<Result<BnWsStreamData, SharedWsError>>)>,
        /// Receives the result of the subscribe request.
        tx: oneshot::Sender<Result<BnWsStreamResponse, SharedWsError>>,
    },
}
#[async_trait::async_trait]
impl ClientExt for BnWsStream {
type Call = BnWsStreamCall;
async fn on_text(&mut self, text: Utf8Bytes) -> Result<(), Error> {
let msg = text.to_string();
if let Some(result) = self.recv_stream_resp(&msg).await {
return result.map_err(|err| err.into());
}
if let Some(result) = self.recv_oneshot_resp(&msg) {
return result.map_err(|err| err.into());
}
tracing::error!(?msg, "unhandled bn ws message");
Err(SharedWsError::AppError("unhandled bn ws message".to_string()).into())
}
async fn on_binary(&mut self, _bytes: Bytes) -> Result<(), Error> {
unimplemented!()
}
async fn on_call(&mut self, call: Self::Call) -> Result<(), Error> {
match call {
BnWsStreamCall::SubscribeStream { streams, tx } => {
self.subscribe_streams(streams, tx)?
}
}
Ok(())
}
async fn on_connect(&mut self) -> Result<(), Error> {
if let Some(tx) = self.on_connect_tx.take() {
tx.send(())
.inspect_err(|err| {
tracing::error!(?err, "failed to send on_connect signal");
})
.map_err(|_| {
SharedWsError::ChannelClosedError("first on connect channel closed".to_string())
})?;
}
Ok(())
}
}
/// Exposes the client handle and pending one-shot map to the shared request/response logic.
impl SharedWsApiTrait<String, BnWsStreamPayload, BnWsStreamResponse> for BnWsStream {
    /// Borrows the underlying `ezsockets` client.
    fn get_client(&self) -> &Client<Self> {
        &self.client
    }

    /// Mutably borrows the map of senders awaiting one-shot replies.
    fn get_oneshot_tx_map(
        &mut self,
    ) -> &mut HashMap<String, oneshot::Sender<Result<BnWsStreamResponse, SharedWsError>>> {
        &mut self.oneshot_tx_map
    }
}
/// Exposes the per-stream data channels to the shared stream-dispatch logic.
#[async_trait::async_trait]
impl SharedWsStreamTrait<String, BnWsStreamData> for BnWsStream {
    /// Mutably borrows the stream-name → data-sender map.
    fn get_stream_tx_map(
        &mut self,
    ) -> &mut HashMap<String, mpsc::Sender<Result<BnWsStreamData, SharedWsError>>> {
        &mut self.stream_tx_map
    }
}
impl BnWsStream {
pub async fn connect(config: ClientConfig) -> Client<Self> {
let (on_connect_tx, on_connect_rx) = oneshot::channel();
let (client, future) = ezsockets::connect(
|client| Self {
client,
on_connect_tx: Some(on_connect_tx),
oneshot_tx_map: Default::default(),
stream_tx_map: Default::default(),
},
config,
)
.await;
tokio::spawn(async move {
future.await.inspect_err(|err| {
tracing::error!(?err, "bn ws client connection error");
})
});
_ = on_connect_rx.await;
client
}
fn subscribe_streams(
&mut self,
streams: Vec<(String, mpsc::Sender<Result<BnWsStreamData, SharedWsError>>)>,
tx: oneshot::Sender<Result<BnWsStreamResponse, SharedWsError>>,
) -> Result<(), SharedWsError> {
if streams.is_empty() {
tracing::warn!("no streams to subscribe");
return Ok(());
}
for (stream, _) in &streams {
if self.stream_tx_map.contains_key(stream) {
tracing::error!(stream, "duplicated stream in ws subscribe stream request");
return Err(SharedWsError::InvalidIdError(stream.clone()));
}
}
let id = Ulid::new().to_string();
let payload = BnWsStreamPayload {
id,
method: BnWsStreamMethod::Subscribe,
params: Some(serde_json::Value::Array(
streams
.iter()
.map(|(stream, _)| serde_json::Value::String(stream.clone()))
.collect::<Vec<_>>(),
)),
};
for (stream, tx) in streams {
self.stream_tx_map.insert(stream, tx);
}
self.send_oneshot(payload, tx)
}
}