hyperlane_plugin_websocket/websocket/
impl.rs

1use crate::*;
2
3impl BroadcastTypeTrait for String {}
4impl BroadcastTypeTrait for &str {}
5impl BroadcastTypeTrait for char {}
6impl BroadcastTypeTrait for bool {}
7impl BroadcastTypeTrait for i8 {}
8impl BroadcastTypeTrait for i16 {}
9impl BroadcastTypeTrait for i32 {}
10impl BroadcastTypeTrait for i64 {}
11impl BroadcastTypeTrait for i128 {}
12impl BroadcastTypeTrait for isize {}
13impl BroadcastTypeTrait for u8 {}
14impl BroadcastTypeTrait for u16 {}
15impl BroadcastTypeTrait for u32 {}
16impl BroadcastTypeTrait for u64 {}
17impl BroadcastTypeTrait for u128 {}
18impl BroadcastTypeTrait for usize {}
19impl BroadcastTypeTrait for f32 {}
20impl BroadcastTypeTrait for f64 {}
21impl BroadcastTypeTrait for IpAddr {}
22impl BroadcastTypeTrait for Ipv4Addr {}
23impl BroadcastTypeTrait for Ipv6Addr {}
24impl BroadcastTypeTrait for SocketAddr {}
25impl BroadcastTypeTrait for NonZeroU8 {}
26impl BroadcastTypeTrait for NonZeroU16 {}
27impl BroadcastTypeTrait for NonZeroU32 {}
28impl BroadcastTypeTrait for NonZeroU64 {}
29impl BroadcastTypeTrait for NonZeroU128 {}
30impl BroadcastTypeTrait for NonZeroUsize {}
31impl BroadcastTypeTrait for NonZeroI8 {}
32impl BroadcastTypeTrait for NonZeroI16 {}
33impl BroadcastTypeTrait for NonZeroI32 {}
34impl BroadcastTypeTrait for NonZeroI64 {}
35impl BroadcastTypeTrait for NonZeroI128 {}
36impl BroadcastTypeTrait for NonZeroIsize {}
37impl BroadcastTypeTrait for Infallible {}
38
39impl BroadcastTypeTrait for &String {}
40impl BroadcastTypeTrait for &&str {}
41impl BroadcastTypeTrait for &char {}
42impl BroadcastTypeTrait for &bool {}
43impl BroadcastTypeTrait for &i8 {}
44impl BroadcastTypeTrait for &i16 {}
45impl BroadcastTypeTrait for &i32 {}
46impl BroadcastTypeTrait for &i64 {}
47impl BroadcastTypeTrait for &i128 {}
48impl BroadcastTypeTrait for &isize {}
49impl BroadcastTypeTrait for &u8 {}
50impl BroadcastTypeTrait for &u16 {}
51impl BroadcastTypeTrait for &u32 {}
52impl BroadcastTypeTrait for &u64 {}
53impl BroadcastTypeTrait for &u128 {}
54impl BroadcastTypeTrait for &usize {}
55impl BroadcastTypeTrait for &f32 {}
56impl BroadcastTypeTrait for &f64 {}
57impl BroadcastTypeTrait for &IpAddr {}
58impl BroadcastTypeTrait for &Ipv4Addr {}
59impl BroadcastTypeTrait for &Ipv6Addr {}
60impl BroadcastTypeTrait for &SocketAddr {}
61impl BroadcastTypeTrait for &NonZeroU8 {}
62impl BroadcastTypeTrait for &NonZeroU16 {}
63impl BroadcastTypeTrait for &NonZeroU32 {}
64impl BroadcastTypeTrait for &NonZeroU64 {}
65impl BroadcastTypeTrait for &NonZeroU128 {}
66impl BroadcastTypeTrait for &NonZeroUsize {}
67impl BroadcastTypeTrait for &NonZeroI8 {}
68impl BroadcastTypeTrait for &NonZeroI16 {}
69impl BroadcastTypeTrait for &NonZeroI32 {}
70impl BroadcastTypeTrait for &NonZeroI64 {}
71impl BroadcastTypeTrait for &NonZeroI128 {}
72impl BroadcastTypeTrait for &NonZeroIsize {}
73impl BroadcastTypeTrait for &Infallible {}
74
75impl<B: BroadcastTypeTrait> Default for BroadcastType<B> {
76    fn default() -> Self {
77        BroadcastType::Unknown
78    }
79}
80
81impl<B: BroadcastTypeTrait> BroadcastType<B> {
82    pub fn get_key(broadcast_type: BroadcastType<B>) -> String {
83        match broadcast_type {
84            BroadcastType::PointToPoint(key1, key2) => {
85                let (first_key, second_key) = if key1 <= key2 {
86                    (key1, key2)
87                } else {
88                    (key2, key1)
89                };
90                format!(
91                    "{}-{}-{}",
92                    POINT_TO_POINT_KEY,
93                    first_key.to_string(),
94                    second_key.to_string()
95                )
96            }
97            BroadcastType::PointToGroup(key) => {
98                format!("{}-{}", POINT_TO_GROUP_KEY, key.to_string())
99            }
100            BroadcastType::Unknown => String::new(),
101        }
102    }
103}
104
105impl<B: BroadcastTypeTrait> Default for WebSocketConfig<B> {
106    fn default() -> Self {
107        let default_hook: ArcFnPinBoxSendSync = Arc::new(move |_| Box::pin(async {}));
108        Self {
109            context: Context::default(),
110            buffer_size: DEFAULT_BUFFER_SIZE,
111            capacity: DEFAULT_BROADCAST_SENDER_CAPACITY,
112            broadcast_type: BroadcastType::default(),
113            request_hook: default_hook.clone(),
114            sended_hook: default_hook.clone(),
115            closed_hook: default_hook,
116        }
117    }
118}
119
120impl<B: BroadcastTypeTrait> WebSocketConfig<B> {
121    pub fn new() -> Self {
122        Self::default()
123    }
124
125    pub fn set_buffer_size(mut self, buffer_size: usize) -> Self {
126        self.buffer_size = buffer_size;
127        self
128    }
129
130    pub fn set_capacity(mut self, capacity: Capacity) -> Self {
131        self.capacity = capacity;
132        self
133    }
134
135    pub fn set_context(mut self, context: Context) -> Self {
136        self.context = context;
137        self
138    }
139
140    pub fn set_broadcast_type(mut self, broadcast_type: BroadcastType<B>) -> Self {
141        self.broadcast_type = broadcast_type;
142        self
143    }
144
145    pub fn set_request_hook<F, Fut>(mut self, hook: F) -> Self
146    where
147        F: Fn(Context) -> Fut + Send + Sync + 'static,
148        Fut: Future<Output = ()> + Send + 'static,
149    {
150        self.request_hook = Arc::new(move |ctx| Box::pin(hook(ctx)));
151        self
152    }
153
154    pub fn set_sended_hook<F, Fut>(mut self, hook: F) -> Self
155    where
156        F: Fn(Context) -> Fut + Send + Sync + 'static,
157        Fut: Future<Output = ()> + Send + 'static,
158    {
159        self.sended_hook = Arc::new(move |ctx| Box::pin(hook(ctx)));
160        self
161    }
162
163    pub fn set_closed_hook<F, Fut>(mut self, hook: F) -> Self
164    where
165        F: Fn(Context) -> Fut + Send + Sync + 'static,
166        Fut: Future<Output = ()> + Send + 'static,
167    {
168        self.closed_hook = Arc::new(move |ctx| Box::pin(hook(ctx)));
169        self
170    }
171
172    pub fn get_context(&self) -> &Context {
173        &self.context
174    }
175
176    pub fn get_buffer_size(&self) -> usize {
177        self.buffer_size
178    }
179
180    pub fn get_capacity(&self) -> Capacity {
181        self.capacity
182    }
183
184    pub fn get_broadcast_type(&self) -> &BroadcastType<B> {
185        &self.broadcast_type
186    }
187
188    pub fn get_request_hook(&self) -> &ArcFnPinBoxSendSync {
189        &self.request_hook
190    }
191
192    pub fn get_sended_hook(&self) -> &ArcFnPinBoxSendSync {
193        &self.sended_hook
194    }
195
196    pub fn get_closed_hook(&self) -> &ArcFnPinBoxSendSync {
197        &self.closed_hook
198    }
199}
200
201impl WebSocket {
202    pub fn new() -> Self {
203        Self {
204            broadcast_map: BroadcastMap::default(),
205        }
206    }
207
208    fn subscribe_unwrap_or_insert<B: BroadcastTypeTrait>(
209        &self,
210        broadcast_type: BroadcastType<B>,
211        capacity: Capacity,
212    ) -> BroadcastMapReceiver<Vec<u8>> {
213        let key: String = BroadcastType::get_key(broadcast_type);
214        self.broadcast_map.subscribe_or_insert(&key, capacity)
215    }
216
217    fn point_to_point<B: BroadcastTypeTrait>(
218        &self,
219        key1: &B,
220        key2: &B,
221        capacity: Capacity,
222    ) -> BroadcastMapReceiver<Vec<u8>> {
223        self.subscribe_unwrap_or_insert(
224            BroadcastType::PointToPoint(key1.clone(), key2.clone()),
225            capacity,
226        )
227    }
228
229    fn point_to_group<B: BroadcastTypeTrait>(
230        &self,
231        key: &B,
232        capacity: Capacity,
233    ) -> BroadcastMapReceiver<Vec<u8>> {
234        self.subscribe_unwrap_or_insert(BroadcastType::PointToGroup(key.clone()), capacity)
235    }
236
237    pub fn receiver_count<'a, B: BroadcastTypeTrait>(
238        &self,
239        broadcast_type: BroadcastType<B>,
240    ) -> ReceiverCount {
241        let key: String = BroadcastType::get_key(broadcast_type);
242        self.broadcast_map.receiver_count(&key).unwrap_or(0)
243    }
244
245    pub fn receiver_count_after_increment<B: BroadcastTypeTrait>(
246        &self,
247        broadcast_type: BroadcastType<B>,
248    ) -> ReceiverCount {
249        let count: ReceiverCount = self.receiver_count(broadcast_type);
250        count.max(0).min(ReceiverCount::MAX - 1) + 1
251    }
252
253    pub fn receiver_count_after_decrement<B: BroadcastTypeTrait>(
254        &self,
255        broadcast_type: BroadcastType<B>,
256    ) -> ReceiverCount {
257        let count: ReceiverCount = self.receiver_count(broadcast_type);
258        count.max(1).min(ReceiverCount::MAX) - 1
259    }
260
261    pub fn send<T, B>(
262        &self,
263        broadcast_type: BroadcastType<B>,
264        data: T,
265    ) -> BroadcastMapSendResult<Vec<u8>>
266    where
267        T: Into<Vec<u8>>,
268        B: BroadcastTypeTrait,
269    {
270        let key: String = BroadcastType::get_key(broadcast_type);
271        self.broadcast_map.send(&key, data.into())
272    }
273
274    pub async fn run<B: BroadcastTypeTrait>(&self, config: WebSocketConfig<B>) {
275        let ctx: Context = config.get_context().clone();
276        if ctx.to_string() == Context::default().to_string() {
277            panic!("Context must be set");
278        }
279        let buffer_size: usize = config.get_buffer_size();
280        let capacity: Capacity = config.get_capacity();
281        let broadcast_type: BroadcastType<B> = config.get_broadcast_type().clone();
282        let mut receiver: Receiver<Vec<u8>> = match &broadcast_type {
283            BroadcastType::PointToPoint(key1, key2) => self.point_to_point(key1, key2, capacity),
284            BroadcastType::PointToGroup(key) => self.point_to_group(key, capacity),
285            BroadcastType::Unknown => panic!("BroadcastType must be PointToPoint or PointToGroup"),
286        };
287        let key: String = BroadcastType::get_key(broadcast_type);
288        let result_handle = || async {
289            ctx.aborted().await;
290            ctx.closed().await;
291        };
292        loop {
293            tokio::select! {
294                request_res = ctx.ws_from_stream(buffer_size) => {
295                    let mut need_break = false;
296                    if request_res.is_ok() {
297                        config.get_request_hook()(ctx.clone()).await;
298                    } else {
299                        need_break = true;
300                        config.get_closed_hook()(ctx.clone()).await;
301                    }
302                    let body: ResponseBody = ctx.get_response_body().await;
303                    let is_err: bool = self.broadcast_map.send(&key, body).is_err();
304                    config.get_sended_hook()(ctx.clone()).await;
305                    if need_break || is_err {
306                        break;
307                    }
308                },
309                msg_res = receiver.recv() => {
310                    if let Ok(msg) = msg_res {
311                        if ctx.set_response_body(msg).await.send_body().await.is_ok() {
312                            continue;
313                        }
314                    }
315                    break;
316                }
317            }
318        }
319        result_handle().await;
320    }
321}