1use serde::Serialize;
2use std::borrow::Borrow;
3use std::collections::{HashMap, HashSet};
4use std::hash::Hash;
5use std::sync::atomic::{AtomicU32, Ordering};
6
7use crate::libs::toolbox::{ArcToolbox, RequestContext};
8
9use super::{ConnectionId, WsResponseGeneric, WsStreamResponseGeneric};
10
11pub struct SubscribeContext<S> {
12 pub ctx: RequestContext,
13 pub stream_seq: AtomicU32,
14 pub settings: S,
15}
16
17pub struct SubscriptionManager<S, Key = ()> {
18 pub stream_code: u32,
19 pub subscribes: HashMap<ConnectionId, SubscribeContext<S>>,
20 pub mappings: HashMap<Key, HashSet<ConnectionId>>,
21}
22
23impl<S, Key: Eq + Hash> SubscriptionManager<S, Key> {
24 pub fn new(stream_code: u32) -> Self {
25 Self {
26 stream_code,
27 subscribes: Default::default(),
28 mappings: Default::default(),
29 }
30 }
31
32 pub fn subscribe(
33 &mut self,
34 ctx: RequestContext,
35 setting: S,
36 modify: impl FnOnce(&mut SubscribeContext<S>),
37 ) {
38 self.subscribe_with(ctx, vec![], || setting, modify)
39 }
40 pub fn subscribe_with_keys(
41 &mut self,
42 ctx: RequestContext,
43 keys: Vec<Key>,
44 setting: S,
45 modify: impl FnOnce(&mut SubscribeContext<S>),
46 ) {
47 self.subscribe_with(ctx, keys, || setting, modify)
48 }
49 pub fn subscribe_with(
50 &mut self,
51 ctx: RequestContext,
52 keys: Vec<Key>,
53 new: impl FnOnce() -> S,
54 modify: impl FnOnce(&mut SubscribeContext<S>),
55 ) {
56 self.subscribes
57 .entry(ctx.connection_id)
58 .and_modify(modify)
59 .or_insert_with(|| SubscribeContext {
60 ctx,
61 stream_seq: AtomicU32::new(0),
62 settings: new(),
63 });
64
65 for key in keys {
66 self.mappings
67 .entry(key)
68 .or_default()
69 .insert(ctx.connection_id);
70 }
71 }
72
73 pub fn unsubscribe(&mut self, connection_id: ConnectionId) {
74 self.subscribes.remove(&connection_id);
75 for pair in self.mappings.values_mut() {
76 pair.remove(&connection_id);
77 }
78 }
79 pub fn unsubscribe_with(
80 &mut self,
81 connection_id: ConnectionId,
82 remove: impl Fn(&mut SubscribeContext<S>) -> (bool, Vec<Key>),
83 ) {
84 let Some((remove1, keys)) = self.subscribes.get_mut(&connection_id).map(remove) else {
85 return;
86 };
87 if remove1 {
88 self.subscribes.remove(&connection_id);
89 }
90 for key in keys {
91 let remove = self
92 .mappings
93 .get_mut(&key)
94 .map(|set| {
95 set.remove(&connection_id);
96 set.is_empty()
97 })
98 .unwrap_or_default();
99 if remove {
100 self.mappings.remove(&key);
101 }
102 }
103 }
104
105 pub fn publish_to(
106 &mut self,
107 toolbox: &ArcToolbox,
108 connection_id: ConnectionId,
109 msg: &impl Serialize,
110 ) {
111 let mut dead_connection = None;
112
113 let Some(sub) = self.subscribes.get(&connection_id) else {
114 return;
115 };
116
117 let data = serde_json::to_value(msg).unwrap();
118
119 let msg = WsResponseGeneric::Stream(WsStreamResponseGeneric {
120 original_seq: sub.ctx.seq,
121 method: sub.ctx.method,
122 stream_seq: sub.stream_seq.fetch_add(1, Ordering::SeqCst),
123 stream_code: self.stream_code,
124 data: data.clone(),
125 });
126
127 if !toolbox.send(sub.ctx.connection_id, msg) {
128 dead_connection = Some(sub.ctx.connection_id);
129 }
130
131 if let Some(conn_id) = dead_connection {
132 self.unsubscribe(conn_id)
133 }
134 }
135 pub fn publish_to_key<Q>(&mut self, toolbox: &ArcToolbox, key: &Q, msg: &impl Serialize)
136 where
137 Key: Borrow<Q>,
138 Q: Eq + Hash + ?Sized,
139 {
140 let Some(conn_ids) = self.mappings.get(key).cloned() else {
141 return;
142 };
143
144 for conn_id in conn_ids {
145 self.publish_to(toolbox, conn_id, msg);
146 }
147 }
148 pub fn publish_to_keys<Q>(&mut self, toolbox: &ArcToolbox, keys: &[&Q], msg: &impl Serialize)
149 where
150 Key: Borrow<Q>,
151 Q: Eq + Hash + ?Sized,
152 {
153 let mut published = HashSet::new();
154 for key in keys {
155 let conn_ids = self.mappings.get(key).cloned();
156 if let Some(conn_ids) = conn_ids {
157 for conn_id in conn_ids.iter() {
158 if published.insert(*conn_id) {
160 self.publish_to(toolbox, *conn_id, msg);
161 }
162 }
163 }
164 }
165 }
166 pub fn publish_with_filter<M: Serialize>(
167 &mut self,
168 toolbox: &ArcToolbox,
169 filter: impl Fn(&SubscribeContext<S>) -> Option<M>,
170 ) {
171 let mut dead_connections = vec![];
172
173 for sub in self.subscribes.values() {
174 let Some(data) = filter(sub) else {
175 continue;
176 };
177 let data = serde_json::to_value(&data).unwrap();
178 let msg = WsResponseGeneric::Stream(WsStreamResponseGeneric {
179 original_seq: sub.ctx.seq,
180 method: sub.ctx.method,
181 stream_seq: sub.stream_seq.fetch_add(1, Ordering::SeqCst),
182 stream_code: self.stream_code,
183 data,
184 });
185
186 if !toolbox.send(sub.ctx.connection_id, msg) {
187 dead_connections.push(sub.ctx.connection_id);
188 }
189 }
190 for conn_id in dead_connections {
191 self.unsubscribe(conn_id);
192 }
193 }
194 pub fn publish_to_all(&mut self, toolbox: &ArcToolbox, msg: &impl Serialize) {
195 self.publish_with_filter(toolbox, |_| Some(msg))
196 }
197}
198
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::libs::toolbox::{RequestContext, Toolbox};

    use super::*;

    /// Unkeyed lifecycle: subscribe, publish through every entry point, then unsubscribe.
    #[tokio::test]
    async fn test_subscribe() {
        let mut manager: SubscriptionManager<(), ()> = SubscriptionManager::new(0);

        let ctx = RequestContext {
            connection_id: 1,
            ..RequestContext::empty()
        };
        manager.subscribe(ctx, (), |_| {});
        assert_eq!(manager.subscribes.len(), 1);
        assert_eq!(manager.mappings.len(), 0);
        let toolbox = Arc::new(Toolbox::new());
        manager.publish_to_all(&toolbox, &());
        manager.publish_to_key(&toolbox, &(), &());
        manager.publish_to_keys(&toolbox, &[], &());
        manager.unsubscribe(ctx.connection_id);
        assert_eq!(manager.subscribes.len(), 0);
        assert_eq!(manager.mappings.len(), 0);
    }

    /// Keyed lifecycle: keys are registered on subscribe and detached by `unsubscribe_with`.
    #[tokio::test]
    async fn test_subscribe_with_keys() {
        let mut manager: SubscriptionManager<(), u32> = SubscriptionManager::new(0);

        let ctx = RequestContext {
            connection_id: 1,
            ..RequestContext::empty()
        };
        manager.subscribe_with_keys(ctx, vec![1, 2], (), |_| {});
        assert_eq!(manager.subscribes.len(), 1);
        assert_eq!(manager.mappings.len(), 2);

        // Detaching a single key keeps the subscription and the other key.
        manager.unsubscribe_with(ctx.connection_id, |_| (false, vec![1]));
        assert_eq!(manager.subscribes.len(), 1);
        assert!(!manager.mappings.contains_key(&1));
        assert!(manager.mappings.contains_key(&2));

        let toolbox = Arc::new(Toolbox::new());
        manager.publish_to_key(&toolbox, &2u32, &());
        manager.publish_to_keys(&toolbox, &[&1u32, &2u32], &());
        manager.unsubscribe(ctx.connection_id);
        assert_eq!(manager.subscribes.len(), 0);
    }
}