hyperlane_plugin_websocket/websocket/
impl.rs

1use crate::*;
2
3impl<'a> BroadcastType<'a> {
4    pub fn get_key(broadcast_type: BroadcastType) -> String {
5        match broadcast_type {
6            BroadcastType::PointToPoint(key1, key2) => {
7                let (first_key, second_key) = if key1 <= key2 {
8                    (key1, key2)
9                } else {
10                    (key2, key1)
11                };
12                format!("{}-{}-{}", POINT_TO_POINT_KEY, first_key, second_key)
13            }
14            BroadcastType::PointToGroup(key) => {
15                format!("{}-{}", POINT_TO_GROUP_KEY, key)
16            }
17        }
18    }
19}
20
21impl WebSocket {
22    pub fn new() -> Self {
23        Self {
24            broadcast_map: BroadcastMap::default(),
25        }
26    }
27
28    fn subscribe_unwrap_or_insert(
29        &self,
30        broadcast_type: BroadcastType,
31    ) -> BroadcastMapReceiver<Vec<u8>> {
32        let key: String = BroadcastType::get_key(broadcast_type);
33        self.get_broadcast_map().subscribe_unwrap_or_insert(&key)
34    }
35
36    fn point_to_point(&self, key1: &str, key2: &str) -> BroadcastMapReceiver<Vec<u8>> {
37        self.subscribe_unwrap_or_insert(BroadcastType::PointToPoint(key1, key2))
38    }
39
40    fn point_to_group(&self, key: &str) -> BroadcastMapReceiver<Vec<u8>> {
41        self.subscribe_unwrap_or_insert(BroadcastType::PointToGroup(key))
42    }
43
44    pub async fn run<'a, F1, Fut1, F2, Fut2>(
45        &self,
46        ctx: &Context,
47        buffer_size: usize,
48        broadcast_type: BroadcastType<'a>,
49        callback: F1,
50        send_callback: F2,
51    ) where
52        F1: FuncWithoutPin<Fut1>,
53        Fut1: Future<Output = ()> + Send + 'static,
54        F2: FuncWithoutPin<Fut2>,
55        Fut2: Future<Output = ()> + Send + 'static,
56    {
57        let mut receiver: Receiver<Vec<u8>> = match broadcast_type {
58            BroadcastType::PointToPoint(key1, key2) => self.point_to_point(key1, key2),
59            BroadcastType::PointToGroup(key) => self.point_to_group(key),
60        };
61        let key: String = BroadcastType::get_key(broadcast_type);
62        loop {
63            tokio::select! {
64                request_res = ctx.websocket_request_from_stream(buffer_size) => {
65                    if request_res.is_err() {
66                        break;
67                    }
68                    callback(ctx.clone()).await;
69                    let body: ResponseBody = ctx.get_response_body().await;
70                    let send_res: BroadcastMapSendResult<_> = self.get_broadcast_map().send(&key, body);
71                    send_callback(ctx.clone()).await;
72                    if send_res.is_err() {
73                        break;
74                    }
75                },
76                msg_res = receiver.recv() => {
77                    if let Ok(msg) = msg_res {
78                        if ctx.send_response_body(msg).await.is_err() || ctx.flush().await.is_err() {
79                            break;
80                        }
81                    }
82               }
83            }
84        }
85    }
86}