endpoint_libs/libs/ws/subs.rs

1use serde::Serialize;
2use std::borrow::Borrow;
3use std::collections::{HashMap, HashSet};
4use std::hash::Hash;
5use std::sync::atomic::{AtomicU32, Ordering};
6
7use crate::libs::toolbox::{ArcToolbox, RequestContext};
8
9use super::{ConnectionId, WsResponseGeneric, WsStreamResponseGeneric};
10
11pub struct SubscribeContext<S> {
12    pub ctx: RequestContext,
13    pub stream_seq: AtomicU32,
14    pub settings: S,
15}
16
17pub struct SubscriptionManager<S, Key = ()> {
18    pub stream_code: u32,
19    pub subscribes: HashMap<ConnectionId, SubscribeContext<S>>,
20    pub mappings: HashMap<Key, HashSet<ConnectionId>>,
21}
22
23impl<S, Key: Eq + Hash> SubscriptionManager<S, Key> {
24    pub fn new(stream_code: u32) -> Self {
25        Self {
26            stream_code,
27            subscribes: Default::default(),
28            mappings: Default::default(),
29        }
30    }
31
32    pub fn subscribe(&mut self, ctx: RequestContext, setting: S, modify: impl FnOnce(&mut SubscribeContext<S>)) {
33        self.subscribe_with(ctx, vec![], || setting, modify)
34    }
35    pub fn subscribe_with_keys(
36        &mut self,
37        ctx: RequestContext,
38        keys: Vec<Key>,
39        setting: S,
40        modify: impl FnOnce(&mut SubscribeContext<S>),
41    ) {
42        self.subscribe_with(ctx, keys, || setting, modify)
43    }
44    pub fn subscribe_with(
45        &mut self,
46        ctx: RequestContext,
47        keys: Vec<Key>,
48        new: impl FnOnce() -> S,
49        modify: impl FnOnce(&mut SubscribeContext<S>),
50    ) {
51        self.subscribes
52            .entry(ctx.connection_id)
53            .and_modify(modify)
54            .or_insert_with(|| SubscribeContext {
55                ctx,
56                stream_seq: AtomicU32::new(0),
57                settings: new(),
58            });
59
60        for key in keys {
61            self.mappings.entry(key).or_default().insert(ctx.connection_id);
62        }
63    }
64
65    pub fn unsubscribe(&mut self, connection_id: ConnectionId) {
66        self.subscribes.remove(&connection_id);
67        for pair in self.mappings.values_mut() {
68            pair.remove(&connection_id);
69        }
70    }
71    pub fn unsubscribe_with(
72        &mut self,
73        connection_id: ConnectionId,
74        remove: impl Fn(&mut SubscribeContext<S>) -> (bool, Vec<Key>),
75    ) {
76        let Some((remove1, keys)) = self.subscribes.get_mut(&connection_id).map(remove) else {
77            return;
78        };
79        if remove1 {
80            self.subscribes.remove(&connection_id);
81        }
82        for key in keys {
83            let remove = self
84                .mappings
85                .get_mut(&key)
86                .map(|set| {
87                    set.remove(&connection_id);
88                    set.is_empty()
89                })
90                .unwrap_or_default();
91            if remove {
92                self.mappings.remove(&key);
93            }
94        }
95    }
96
97    pub fn publish_to(&mut self, toolbox: &ArcToolbox, connection_id: ConnectionId, msg: &impl Serialize) {
98        let mut dead_connection = None;
99
100        let Some(sub) = self.subscribes.get(&connection_id) else {
101            return;
102        };
103
104        let data = serde_json::to_value(msg).unwrap();
105
106        let msg = WsResponseGeneric::Stream(WsStreamResponseGeneric {
107            original_seq: sub.ctx.seq,
108            method: sub.ctx.method,
109            stream_seq: sub.stream_seq.fetch_add(1, Ordering::SeqCst),
110            stream_code: self.stream_code,
111            data: data.clone(),
112        });
113
114        if !toolbox.send(sub.ctx.connection_id, msg) {
115            dead_connection = Some(sub.ctx.connection_id);
116        }
117
118        if let Some(conn_id) = dead_connection {
119            self.unsubscribe(conn_id)
120        }
121    }
122    pub fn publish_to_key<Q>(&mut self, toolbox: &ArcToolbox, key: &Q, msg: &impl Serialize)
123    where
124        Key: Borrow<Q>,
125        Q: Eq + Hash + ?Sized,
126    {
127        let Some(conn_ids) = self.mappings.get(key).cloned() else {
128            return;
129        };
130
131        for conn_id in conn_ids {
132            self.publish_to(toolbox, conn_id, msg);
133        }
134    }
135    pub fn publish_to_keys<Q>(&mut self, toolbox: &ArcToolbox, keys: &[&Q], msg: &impl Serialize)
136    where
137        Key: Borrow<Q>,
138        Q: Eq + Hash + ?Sized,
139    {
140        let mut published = HashSet::new();
141        for key in keys {
142            let conn_ids = self.mappings.get(key).cloned();
143            if let Some(conn_ids) = conn_ids {
144                for conn_id in conn_ids.iter() {
145                    // if newly inserted
146                    if published.insert(*conn_id) {
147                        self.publish_to(toolbox, *conn_id, msg);
148                    }
149                }
150            }
151        }
152    }
153    pub fn publish_with_filter<M: Serialize>(
154        &mut self,
155        toolbox: &ArcToolbox,
156        filter: impl Fn(&SubscribeContext<S>) -> Option<M>,
157    ) {
158        let mut dead_connections = vec![];
159
160        for sub in self.subscribes.values() {
161            let Some(data) = filter(sub) else {
162                continue;
163            };
164            let data = serde_json::to_value(&data).unwrap();
165            let msg = WsResponseGeneric::Stream(WsStreamResponseGeneric {
166                original_seq: sub.ctx.seq,
167                method: sub.ctx.method,
168                stream_seq: sub.stream_seq.fetch_add(1, Ordering::SeqCst),
169                stream_code: self.stream_code,
170                data,
171            });
172
173            if !toolbox.send(sub.ctx.connection_id, msg) {
174                dead_connections.push(sub.ctx.connection_id);
175            }
176        }
177        for conn_id in dead_connections {
178            self.unsubscribe(conn_id);
179        }
180    }
181    pub fn publish_to_all(&mut self, toolbox: &ArcToolbox, msg: &impl Serialize) {
182        self.publish_with_filter(toolbox, |_| Some(msg))
183    }
184}
185
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::libs::toolbox::{RequestContext, Toolbox};

    use super::*;

    /// Subscribing without keys, publishing through every entry point, then unsubscribing
    /// leaves the manager empty.
    #[tokio::test]
    async fn test_subscribe() {
        let mut manager: SubscriptionManager<(), ()> = SubscriptionManager::new(0);

        let ctx = RequestContext {
            connection_id: 1,
            ..RequestContext::empty()
        };
        manager.subscribe(ctx, (), |_| {});
        assert_eq!(manager.subscribes.len(), 1);
        assert_eq!(manager.mappings.len(), 0);
        let toolbox = Arc::new(Toolbox::new());
        manager.publish_to_all(&toolbox, &());
        manager.publish_to_key(&toolbox, &(), &());
        manager.publish_to_keys(&toolbox, &[], &());
        manager.unsubscribe(ctx.connection_id);
        assert_eq!(manager.subscribes.len(), 0);
        assert_eq!(manager.mappings.len(), 0);
    }

    /// Keyed subscriptions populate `mappings`, and a partial `unsubscribe_with` detaches
    /// only the requested key (dropping it once empty) while keeping the subscription.
    #[tokio::test]
    async fn test_subscribe_with_keys_and_partial_unsubscribe() {
        let mut manager: SubscriptionManager<(), u32> = SubscriptionManager::new(0);

        let ctx = RequestContext {
            connection_id: 1,
            ..RequestContext::empty()
        };
        manager.subscribe_with_keys(ctx, vec![10, 20], (), |_| {});
        assert_eq!(manager.subscribes.len(), 1);
        assert_eq!(manager.mappings.len(), 2);
        assert!(manager.mappings[&10u32].contains(&ctx.connection_id));
        assert!(manager.mappings[&20u32].contains(&ctx.connection_id));

        manager.unsubscribe_with(ctx.connection_id, |_| (false, vec![10]));
        assert_eq!(manager.subscribes.len(), 1);
        assert_eq!(manager.mappings.len(), 1);
        assert!(manager.mappings.contains_key(&20u32));

        manager.unsubscribe(ctx.connection_id);
        assert_eq!(manager.subscribes.len(), 0);
    }
}