1use futures_core::Stream;
2use futures_util::{SinkExt as _, StreamExt};
3use std::pin::Pin;
4use tokio_stream::wrappers::UnboundedReceiverStream;
5use tonic::transport::{Endpoint, Uri};
6
7use crate::error::{M10Error, M10Result};
8use m10_protos::prost::Message;
9use m10_protos::sdk;
10use tokio::sync::mpsc;
11use tokio::sync::mpsc::UnboundedSender;
12use tokio_tungstenite::connect_async;
13use tokio_tungstenite::tungstenite::Message as WSMessage;
14
15#[derive(Clone)]
16pub struct WSClient {
17 endpoint: Endpoint,
18}
19
20impl WSClient {
21 pub fn new(endpoint: Endpoint) -> Self {
22 Self { endpoint }
23 }
24
25 pub async fn observe_with_request<T, F>(
26 &self,
27 ep: &str,
28 req: sdk::RequestEnvelope,
29 f: F,
30 ) -> M10Result<Pin<Box<dyn Stream<Item = M10Result<T>> + Send + Sync + 'static>>>
31 where
32 F: FnMut(Vec<u8>) -> M10Result<T> + Send + Sync + 'static,
33 {
34 let (msg_tx, msg_rx) = mpsc::unbounded_channel();
35
36 tokio::spawn({
37 let msg_tx = msg_tx.clone();
38 let base_url = self.endpoint.uri().clone();
39 let endpoint = ep.to_string().clone();
40
41 async move {
42 if let Err(err) = observe_msgs(msg_tx, req, base_url, endpoint).await {
43 eprintln!("Failed to spawn WebSocket client thread: {}", err);
44 }
45 }
46 });
47
48 Ok(Box::pin(UnboundedReceiverStream::new(msg_rx).map(f)))
49 }
50}
51
52async fn observe_msgs(
53 msg_tx: UnboundedSender<Vec<u8>>,
54 req: sdk::RequestEnvelope,
55 base_url: Uri,
56 endpoint: String,
57) -> M10Result<()> {
58 let (mut ws, _) = connect_async(format!("{}ledger/ws/observe/{}", base_url, endpoint))
59 .await
60 .map_err(M10Error::from)?;
61
62 let mut req_body = vec![];
63 req.encode(&mut req_body).expect("Failed to encode");
64
65 ws.send(req_body.into()).await.map_err(M10Error::from)?;
66
67 while let Some(msg) = ws.next().await {
68 match msg {
69 Ok(WSMessage::Binary(bin)) => {
70 if msg_tx.send(bin).is_err() {
71 break;
72 }
73 }
74
75 Ok(WSMessage::Ping(_)) => {
76 ws.send(WSMessage::Pong(Vec::new()))
77 .await
78 .map_err(M10Error::from)?;
79 }
80
81 Err(e) => {
82 eprintln!(
83 "Error during listening messages from the WebSocket connection: {:?}",
84 e
85 );
86 return Err(M10Error::WsError(e));
87 }
88
89 _ => {}
90 }
91 }
92
93 Ok(())
94}