hyperlane_plugin_websocket/websocket/impl.rs

1use crate::*;
2
3impl<'a> BroadcastType<'a> {
4    pub fn get_key(broadcast_type: BroadcastType) -> String {
5        match broadcast_type {
6            BroadcastType::PointToPoint(key1, key2) => {
7                let (first_key, second_key) = if key1 <= key2 {
8                    (key1, key2)
9                } else {
10                    (key2, key1)
11                };
12                format!("{}-{}-{}", POINT_TO_POINT_KEY, first_key, second_key)
13            }
14            BroadcastType::PointToGroup(key) => {
15                format!("{}-{}", POINT_TO_GROUP_KEY, key)
16            }
17        }
18    }
19}
20
21impl WebSocket {
22    pub fn new() -> Self {
23        Self {
24            broadcast_map: BroadcastMap::default(),
25        }
26    }
27
28    fn subscribe_unwrap_or_insert(
29        &self,
30        broadcast_type: BroadcastType,
31    ) -> BroadcastMapReceiver<Vec<u8>> {
32        let key: String = BroadcastType::get_key(broadcast_type);
33        self.broadcast_map.subscribe_unwrap_or_insert(&key)
34    }
35
36    fn point_to_point(&self, key1: &str, key2: &str) -> BroadcastMapReceiver<Vec<u8>> {
37        self.subscribe_unwrap_or_insert(BroadcastType::PointToPoint(key1, key2))
38    }
39
40    fn point_to_group(&self, key: &str) -> BroadcastMapReceiver<Vec<u8>> {
41        self.subscribe_unwrap_or_insert(BroadcastType::PointToGroup(key))
42    }
43
44    pub fn receiver_count<'a>(&self, broadcast_type: BroadcastType<'a>) -> OptionReceiverCount {
45        let key: String = BroadcastType::get_key(broadcast_type);
46        self.broadcast_map.receiver_count(&key)
47    }
48
49    pub async fn run<'a, F1, Fut1, F2, Fut2, F3, Fut3>(
50        &self,
51        ctx: &Context,
52        buffer_size: usize,
53        broadcast_type: BroadcastType<'a>,
54        callback: F1,
55        send_callback: F2,
56        client_closed_callback: F3,
57    ) where
58        F1: FuncWithoutPin<Fut1>,
59        Fut1: Future<Output = ()> + Send + 'static,
60        F2: FuncWithoutPin<Fut2>,
61        Fut2: Future<Output = ()> + Send + 'static,
62        F3: FuncWithoutPin<Fut3>,
63        Fut3: Future<Output = ()> + Send + 'static,
64    {
65        let mut receiver: Receiver<Vec<u8>> = match broadcast_type {
66            BroadcastType::PointToPoint(key1, key2) => self.point_to_point(key1, key2),
67            BroadcastType::PointToGroup(key) => self.point_to_group(key),
68        };
69        let key: String = BroadcastType::get_key(broadcast_type);
70        let result_handle = || async {
71            ctx.aborted().await;
72            ctx.closed().await;
73        };
74        loop {
75            tokio::select! {
76                request_res = ctx.websocket_request_from_stream(buffer_size) => {
77                    if request_res.is_err() {
78                        if let Err(RequestError::ClientClosedConnection) = request_res {
79                            client_closed_callback(ctx.clone()).await;
80                        }
81                        break;
82                    }
83                    callback(ctx.clone()).await;
84                    let body: ResponseBody = ctx.get_response_body().await;
85                    let send_res: BroadcastMapSendResult<_> = self.broadcast_map.send(&key, body);
86                    send_callback(ctx.clone()).await;
87                    if send_res.is_err() {
88                        break;
89                    }
90                },
91                msg_res = receiver.recv() => {
92                    if let Ok(msg) = msg_res {
93                        if ctx.send_response_body(msg).await.is_err() || ctx.flush().await.is_err() {
94                            break;
95                        }
96                    }
97               }
98            }
99        }
100        result_handle().await;
101    }
102}