hyperlane_plugin_websocket/websocket/
impl.rs1use crate::*;
2
3impl<'a> BroadcastType<'a> {
4 pub fn get_key(broadcast_type: BroadcastType) -> String {
5 match broadcast_type {
6 BroadcastType::PointToPoint(key1, key2) => {
7 let (first_key, second_key) = if key1 <= key2 {
8 (key1, key2)
9 } else {
10 (key2, key1)
11 };
12 format!("{}-{}-{}", POINT_TO_POINT_KEY, first_key, second_key)
13 }
14 BroadcastType::PointToGroup(key) => {
15 format!("{}-{}", POINT_TO_GROUP_KEY, key)
16 }
17 }
18 }
19}
20
21impl WebSocket {
22 pub fn new() -> Self {
23 Self {
24 broadcast_map: BroadcastMap::default(),
25 }
26 }
27
28 fn subscribe_unwrap_or_insert(
29 &self,
30 broadcast_type: BroadcastType,
31 ) -> BroadcastMapReceiver<Vec<u8>> {
32 let key: String = BroadcastType::get_key(broadcast_type);
33 self.broadcast_map.subscribe_unwrap_or_insert(&key)
34 }
35
36 fn point_to_point(&self, key1: &str, key2: &str) -> BroadcastMapReceiver<Vec<u8>> {
37 self.subscribe_unwrap_or_insert(BroadcastType::PointToPoint(key1, key2))
38 }
39
40 fn point_to_group(&self, key: &str) -> BroadcastMapReceiver<Vec<u8>> {
41 self.subscribe_unwrap_or_insert(BroadcastType::PointToGroup(key))
42 }
43
44 pub fn receiver_count<'a>(&self, broadcast_type: BroadcastType<'a>) -> OptionReceiverCount {
45 let key: String = BroadcastType::get_key(broadcast_type);
46 self.broadcast_map.receiver_count(&key)
47 }
48
49 pub async fn run<'a, F1, Fut1, F2, Fut2, F3, Fut3>(
50 &self,
51 ctx: &Context,
52 buffer_size: usize,
53 broadcast_type: BroadcastType<'a>,
54 callback: F1,
55 send_callback: F2,
56 client_closed_callback: F3,
57 ) where
58 F1: FuncWithoutPin<Fut1>,
59 Fut1: Future<Output = ()> + Send + 'static,
60 F2: FuncWithoutPin<Fut2>,
61 Fut2: Future<Output = ()> + Send + 'static,
62 F3: FuncWithoutPin<Fut3>,
63 Fut3: Future<Output = ()> + Send + 'static,
64 {
65 let mut receiver: Receiver<Vec<u8>> = match broadcast_type {
66 BroadcastType::PointToPoint(key1, key2) => self.point_to_point(key1, key2),
67 BroadcastType::PointToGroup(key) => self.point_to_group(key),
68 };
69 let key: String = BroadcastType::get_key(broadcast_type);
70 let result_handle = || async {
71 ctx.aborted().await;
72 ctx.closed().await;
73 };
74 loop {
75 tokio::select! {
76 request_res = ctx.websocket_request_from_stream(buffer_size) => {
77 if request_res.is_err() {
78 if let Err(RequestError::ClientClosedConnection) = request_res {
79 client_closed_callback(ctx.clone()).await;
80 }
81 break;
82 }
83 callback(ctx.clone()).await;
84 let body: ResponseBody = ctx.get_response_body().await;
85 let send_res: BroadcastMapSendResult<_> = self.broadcast_map.send(&key, body);
86 send_callback(ctx.clone()).await;
87 if send_res.is_err() {
88 break;
89 }
90 },
91 msg_res = receiver.recv() => {
92 if let Ok(msg) = msg_res {
93 if ctx.send_response_body(msg).await.is_err() || ctx.flush().await.is_err() {
94 break;
95 }
96 }
97 }
98 }
99 }
100 result_handle().await;
101 }
102}