endpoint_libs/libs/ws/
subs.rs

1use serde::Serialize;
2use std::borrow::Borrow;
3use std::collections::{HashMap, HashSet};
4use std::hash::Hash;
5use std::sync::atomic::{AtomicU32, Ordering};
6
7use crate::libs::toolbox::{ArcToolbox, RequestContext};
8
9use super::{ConnectionId, WsResponseGeneric, WsStreamResponseGeneric};
10
11pub struct SubscribeContext<S> {
12    pub ctx: RequestContext,
13    pub stream_seq: AtomicU32,
14    pub settings: S,
15}
16
17pub struct SubscriptionManager<S, Key = ()> {
18    pub stream_code: u32,
19    pub subscribes: HashMap<ConnectionId, SubscribeContext<S>>,
20    pub mappings: HashMap<Key, HashSet<ConnectionId>>,
21}
22
23impl<S, Key: Eq + Hash> SubscriptionManager<S, Key> {
24    pub fn new(stream_code: u32) -> Self {
25        Self {
26            stream_code,
27            subscribes: Default::default(),
28            mappings: Default::default(),
29        }
30    }
31
32    pub fn subscribe(
33        &mut self,
34        ctx: RequestContext,
35        setting: S,
36        modify: impl FnOnce(&mut SubscribeContext<S>),
37    ) {
38        self.subscribe_with(ctx, vec![], || setting, modify)
39    }
40    pub fn subscribe_with_keys(
41        &mut self,
42        ctx: RequestContext,
43        keys: Vec<Key>,
44        setting: S,
45        modify: impl FnOnce(&mut SubscribeContext<S>),
46    ) {
47        self.subscribe_with(ctx, keys, || setting, modify)
48    }
49    pub fn subscribe_with(
50        &mut self,
51        ctx: RequestContext,
52        keys: Vec<Key>,
53        new: impl FnOnce() -> S,
54        modify: impl FnOnce(&mut SubscribeContext<S>),
55    ) {
56        self.subscribes
57            .entry(ctx.connection_id)
58            .and_modify(modify)
59            .or_insert_with(|| SubscribeContext {
60                ctx,
61                stream_seq: AtomicU32::new(0),
62                settings: new(),
63            });
64
65        for key in keys {
66            self.mappings
67                .entry(key)
68                .or_default()
69                .insert(ctx.connection_id);
70        }
71    }
72
73    pub fn unsubscribe(&mut self, connection_id: ConnectionId) {
74        self.subscribes.remove(&connection_id);
75        for pair in self.mappings.values_mut() {
76            pair.remove(&connection_id);
77        }
78    }
79    pub fn unsubscribe_with(
80        &mut self,
81        connection_id: ConnectionId,
82        remove: impl Fn(&mut SubscribeContext<S>) -> (bool, Vec<Key>),
83    ) {
84        let Some((remove1, keys)) = self.subscribes.get_mut(&connection_id).map(remove) else {
85            return;
86        };
87        if remove1 {
88            self.subscribes.remove(&connection_id);
89        }
90        for key in keys {
91            let remove = self
92                .mappings
93                .get_mut(&key)
94                .map(|set| {
95                    set.remove(&connection_id);
96                    set.is_empty()
97                })
98                .unwrap_or_default();
99            if remove {
100                self.mappings.remove(&key);
101            }
102        }
103    }
104
105    pub fn publish_to(
106        &mut self,
107        toolbox: &ArcToolbox,
108        connection_id: ConnectionId,
109        msg: &impl Serialize,
110    ) {
111        let mut dead_connection = None;
112
113        let Some(sub) = self.subscribes.get(&connection_id) else {
114            return;
115        };
116
117        let data = serde_json::to_value(msg).unwrap();
118
119        let msg = WsResponseGeneric::Stream(WsStreamResponseGeneric {
120            original_seq: sub.ctx.seq,
121            method: sub.ctx.method,
122            stream_seq: sub.stream_seq.fetch_add(1, Ordering::SeqCst),
123            stream_code: self.stream_code,
124            data: data.clone(),
125        });
126
127        if !toolbox.send(sub.ctx.connection_id, msg) {
128            dead_connection = Some(sub.ctx.connection_id);
129        }
130
131        if let Some(conn_id) = dead_connection {
132            self.unsubscribe(conn_id)
133        }
134    }
135    pub fn publish_to_key<Q>(&mut self, toolbox: &ArcToolbox, key: &Q, msg: &impl Serialize)
136    where
137        Key: Borrow<Q>,
138        Q: Eq + Hash + ?Sized,
139    {
140        let Some(conn_ids) = self.mappings.get(key).cloned() else {
141            return;
142        };
143
144        for conn_id in conn_ids {
145            self.publish_to(toolbox, conn_id, msg);
146        }
147    }
148    pub fn publish_to_keys<Q>(&mut self, toolbox: &ArcToolbox, keys: &[&Q], msg: &impl Serialize)
149    where
150        Key: Borrow<Q>,
151        Q: Eq + Hash + ?Sized,
152    {
153        let mut published = HashSet::new();
154        for key in keys {
155            let conn_ids = self.mappings.get(key).cloned();
156            if let Some(conn_ids) = conn_ids {
157                for conn_id in conn_ids.iter() {
158                    // if newly inserted
159                    if published.insert(*conn_id) {
160                        self.publish_to(toolbox, *conn_id, msg);
161                    }
162                }
163            }
164        }
165    }
166    pub fn publish_with_filter<M: Serialize>(
167        &mut self,
168        toolbox: &ArcToolbox,
169        filter: impl Fn(&SubscribeContext<S>) -> Option<M>,
170    ) {
171        let mut dead_connections = vec![];
172
173        for sub in self.subscribes.values() {
174            let Some(data) = filter(sub) else {
175                continue;
176            };
177            let data = serde_json::to_value(&data).unwrap();
178            let msg = WsResponseGeneric::Stream(WsStreamResponseGeneric {
179                original_seq: sub.ctx.seq,
180                method: sub.ctx.method,
181                stream_seq: sub.stream_seq.fetch_add(1, Ordering::SeqCst),
182                stream_code: self.stream_code,
183                data,
184            });
185
186            if !toolbox.send(sub.ctx.connection_id, msg) {
187                dead_connections.push(sub.ctx.connection_id);
188            }
189        }
190        for conn_id in dead_connections {
191            self.unsubscribe(conn_id);
192        }
193    }
194    pub fn publish_to_all(&mut self, toolbox: &ArcToolbox, msg: &impl Serialize) {
195        self.publish_with_filter(toolbox, |_| Some(msg))
196    }
197}
198
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::libs::toolbox::{RequestContext, Toolbox};

    pub(super) use super::*;

    /// Smoke test: a key-less subscription can be created, every publish entry point can be
    /// called on it, and unsubscribing leaves the manager empty.
    #[tokio::test]
    async fn test_subscribe() {
        let toolbox = Arc::new(Toolbox::new());
        let request = RequestContext {
            connection_id: 1,
            ..RequestContext::empty()
        };
        let mut subs: SubscriptionManager<(), ()> = SubscriptionManager::new(0);

        // Subscribing without keys registers the connection but creates no key mapping.
        subs.subscribe(request, (), |_| {});
        assert_eq!(subs.subscribes.len(), 1);
        assert_eq!(subs.mappings.len(), 0);

        // Exercise each publish entry point once.
        subs.publish_to_all(&toolbox, &());
        subs.publish_to_key(&toolbox, &(), &());
        subs.publish_to_keys(&toolbox, &[], &());

        subs.unsubscribe(request.connection_id);
        assert_eq!(subs.subscribes.len(), 0);
        assert_eq!(subs.mappings.len(), 0);
    }
}