hyperlane_plugin_websocket/websocket/impl.rs

1use crate::*;
2
3impl BroadcastTypeTrait for String {}
4impl BroadcastTypeTrait for &str {}
5impl BroadcastTypeTrait for char {}
6impl BroadcastTypeTrait for bool {}
7impl BroadcastTypeTrait for i8 {}
8impl BroadcastTypeTrait for i16 {}
9impl BroadcastTypeTrait for i32 {}
10impl BroadcastTypeTrait for i64 {}
11impl BroadcastTypeTrait for i128 {}
12impl BroadcastTypeTrait for isize {}
13impl BroadcastTypeTrait for u8 {}
14impl BroadcastTypeTrait for u16 {}
15impl BroadcastTypeTrait for u32 {}
16impl BroadcastTypeTrait for u64 {}
17impl BroadcastTypeTrait for u128 {}
18impl BroadcastTypeTrait for usize {}
19impl BroadcastTypeTrait for f32 {}
20impl BroadcastTypeTrait for f64 {}
21impl BroadcastTypeTrait for IpAddr {}
22impl BroadcastTypeTrait for Ipv4Addr {}
23impl BroadcastTypeTrait for Ipv6Addr {}
24impl BroadcastTypeTrait for SocketAddr {}
25impl BroadcastTypeTrait for NonZeroU8 {}
26impl BroadcastTypeTrait for NonZeroU16 {}
27impl BroadcastTypeTrait for NonZeroU32 {}
28impl BroadcastTypeTrait for NonZeroU64 {}
29impl BroadcastTypeTrait for NonZeroU128 {}
30impl BroadcastTypeTrait for NonZeroUsize {}
31impl BroadcastTypeTrait for NonZeroI8 {}
32impl BroadcastTypeTrait for NonZeroI16 {}
33impl BroadcastTypeTrait for NonZeroI32 {}
34impl BroadcastTypeTrait for NonZeroI64 {}
35impl BroadcastTypeTrait for NonZeroI128 {}
36impl BroadcastTypeTrait for NonZeroIsize {}
37impl BroadcastTypeTrait for Infallible {}
38
39impl BroadcastTypeTrait for &String {}
40impl BroadcastTypeTrait for &&str {}
41impl BroadcastTypeTrait for &char {}
42impl BroadcastTypeTrait for &bool {}
43impl BroadcastTypeTrait for &i8 {}
44impl BroadcastTypeTrait for &i16 {}
45impl BroadcastTypeTrait for &i32 {}
46impl BroadcastTypeTrait for &i64 {}
47impl BroadcastTypeTrait for &i128 {}
48impl BroadcastTypeTrait for &isize {}
49impl BroadcastTypeTrait for &u8 {}
50impl BroadcastTypeTrait for &u16 {}
51impl BroadcastTypeTrait for &u32 {}
52impl BroadcastTypeTrait for &u64 {}
53impl BroadcastTypeTrait for &u128 {}
54impl BroadcastTypeTrait for &usize {}
55impl BroadcastTypeTrait for &f32 {}
56impl BroadcastTypeTrait for &f64 {}
57impl BroadcastTypeTrait for &IpAddr {}
58impl BroadcastTypeTrait for &Ipv4Addr {}
59impl BroadcastTypeTrait for &Ipv6Addr {}
60impl BroadcastTypeTrait for &SocketAddr {}
61impl BroadcastTypeTrait for &NonZeroU8 {}
62impl BroadcastTypeTrait for &NonZeroU16 {}
63impl BroadcastTypeTrait for &NonZeroU32 {}
64impl BroadcastTypeTrait for &NonZeroU64 {}
65impl BroadcastTypeTrait for &NonZeroU128 {}
66impl BroadcastTypeTrait for &NonZeroUsize {}
67impl BroadcastTypeTrait for &NonZeroI8 {}
68impl BroadcastTypeTrait for &NonZeroI16 {}
69impl BroadcastTypeTrait for &NonZeroI32 {}
70impl BroadcastTypeTrait for &NonZeroI64 {}
71impl BroadcastTypeTrait for &NonZeroI128 {}
72impl BroadcastTypeTrait for &NonZeroIsize {}
73impl BroadcastTypeTrait for &Infallible {}
74
75impl<B: BroadcastTypeTrait> BroadcastType<B> {
76    pub fn get_key(broadcast_type: BroadcastType<B>) -> String {
77        match broadcast_type {
78            BroadcastType::PointToPoint(key1, key2) => {
79                let (first_key, second_key) = if key1 <= key2 {
80                    (key1, key2)
81                } else {
82                    (key2, key1)
83                };
84                format!(
85                    "{}-{}-{}",
86                    POINT_TO_POINT_KEY,
87                    first_key.to_string(),
88                    second_key.to_string()
89                )
90            }
91            BroadcastType::PointToGroup(key) => {
92                format!("{}-{}", POINT_TO_GROUP_KEY, key.to_string())
93            }
94        }
95    }
96}
97
98impl WebSocket {
99    pub fn new() -> Self {
100        Self {
101            broadcast_map: BroadcastMap::default(),
102        }
103    }
104
105    fn subscribe_unwrap_or_insert<B: BroadcastTypeTrait>(
106        &self,
107        broadcast_type: BroadcastType<B>,
108        capacity: Capacity,
109    ) -> BroadcastMapReceiver<Vec<u8>> {
110        let key: String = BroadcastType::get_key(broadcast_type);
111        self.broadcast_map.subscribe_or_insert(&key, capacity)
112    }
113
114    fn point_to_point<B: BroadcastTypeTrait>(
115        &self,
116        key1: &B,
117        key2: &B,
118        capacity: Capacity,
119    ) -> BroadcastMapReceiver<Vec<u8>> {
120        self.subscribe_unwrap_or_insert(
121            BroadcastType::PointToPoint(key1.clone(), key2.clone()),
122            capacity,
123        )
124    }
125
126    fn point_to_group<B: BroadcastTypeTrait>(
127        &self,
128        key: &B,
129        capacity: Capacity,
130    ) -> BroadcastMapReceiver<Vec<u8>> {
131        self.subscribe_unwrap_or_insert(BroadcastType::PointToGroup(key.clone()), capacity)
132    }
133
134    pub fn receiver_count<'a, B: BroadcastTypeTrait>(
135        &self,
136        broadcast_type: BroadcastType<B>,
137    ) -> ReceiverCount {
138        let key: String = BroadcastType::get_key(broadcast_type);
139        self.broadcast_map.receiver_count(&key).unwrap_or(0)
140    }
141
142    pub fn receiver_count_after_increment<B: BroadcastTypeTrait>(
143        &self,
144        broadcast_type: BroadcastType<B>,
145    ) -> ReceiverCount {
146        let count: ReceiverCount = self.receiver_count(broadcast_type);
147        count.max(0).min(ReceiverCount::MAX - 1) + 1
148    }
149
150    pub fn receiver_count_after_decrement<B: BroadcastTypeTrait>(
151        &self,
152        broadcast_type: BroadcastType<B>,
153    ) -> ReceiverCount {
154        let count: ReceiverCount = self.receiver_count(broadcast_type);
155        count.max(1).min(ReceiverCount::MAX) - 1
156    }
157
158    pub fn send<T, B>(
159        &self,
160        broadcast_type: BroadcastType<B>,
161        data: T,
162    ) -> BroadcastMapSendResult<Vec<u8>>
163    where
164        T: Into<Vec<u8>>,
165        B: BroadcastTypeTrait,
166    {
167        let key: String = BroadcastType::get_key(broadcast_type);
168        self.broadcast_map.send(&key, data.into())
169    }
170
171    pub async fn run<'a, F1, Fut1, F2, Fut2, F3, Fut3, B>(
172        &self,
173        ctx: Context,
174        buffer_size: usize,
175        capacity: Capacity,
176        broadcast_type: BroadcastType<B>,
177        request_hook: F1,
178        sended_hook: F2,
179        closed_hook: F3,
180    ) where
181        F1: FnSendSyncStatic<Fut1>,
182        Fut1: FutureSendStatic<()>,
183        F2: FnSendSyncStatic<Fut2>,
184        Fut2: FutureSendStatic<()>,
185        F3: FnSendSyncStatic<Fut3>,
186        Fut3: FutureSendStatic<()>,
187        B: BroadcastTypeTrait,
188    {
189        let mut receiver: Receiver<Vec<u8>> = match &broadcast_type {
190            BroadcastType::PointToPoint(key1, key2) => self.point_to_point(key1, key2, capacity),
191            BroadcastType::PointToGroup(key) => self.point_to_group(key, capacity),
192        };
193        let key: String = BroadcastType::get_key(broadcast_type);
194        let result_handle = || async {
195            ctx.aborted().await;
196            ctx.closed().await;
197        };
198        loop {
199            tokio::select! {
200                request_res = ctx.ws_from_stream(buffer_size) => {
201                    let mut need_break = false;
202                    if request_res.is_ok() {
203                        request_hook(ctx.clone()).await;
204                    } else {
205                        need_break = true;
206                        closed_hook(ctx.clone()).await;
207                    }
208                    let body: ResponseBody = ctx.get_response_body().await;
209                    let is_err: bool = self.broadcast_map.send(&key, body).is_err();
210                    sended_hook(ctx.clone()).await;
211                    if need_break || is_err {
212                        break;
213                    }
214                },
215                msg_res = receiver.recv() => {
216                    if let Ok(msg) = msg_res {
217                        if ctx.set_response_body(msg).await.send_body().await.is_ok() {
218                            continue;
219                        }
220                    }
221                    break;
222               }
223            }
224        }
225        result_handle().await;
226    }
227}