1use serde::Serialize;
2use std::borrow::Borrow;
3use std::collections::{HashMap, HashSet};
4use std::hash::Hash;
5use std::sync::atomic::{AtomicU32, Ordering};
6
7use crate::libs::toolbox::{ArcToolbox, RequestContext};
8
9use super::{ConnectionId, WsResponseGeneric, WsStreamResponseGeneric};
10
11pub struct SubscribeContext<S> {
12 pub ctx: RequestContext,
13 pub stream_seq: AtomicU32,
14 pub settings: S,
15}
16
17pub struct SubscriptionManager<S, Key = ()> {
18 pub stream_code: u32,
19 pub subscribes: HashMap<ConnectionId, SubscribeContext<S>>,
20 pub mappings: HashMap<Key, HashSet<ConnectionId>>,
21}
22
23impl<S, Key: Eq + Hash> SubscriptionManager<S, Key> {
24 pub fn new(stream_code: u32) -> Self {
25 Self {
26 stream_code,
27 subscribes: Default::default(),
28 mappings: Default::default(),
29 }
30 }
31
32 pub fn subscribe(&mut self, ctx: RequestContext, setting: S, modify: impl FnOnce(&mut SubscribeContext<S>)) {
33 self.subscribe_with(ctx, vec![], || setting, modify)
34 }
35 pub fn subscribe_with_keys(
36 &mut self,
37 ctx: RequestContext,
38 keys: Vec<Key>,
39 setting: S,
40 modify: impl FnOnce(&mut SubscribeContext<S>),
41 ) {
42 self.subscribe_with(ctx, keys, || setting, modify)
43 }
44 pub fn subscribe_with(
45 &mut self,
46 ctx: RequestContext,
47 keys: Vec<Key>,
48 new: impl FnOnce() -> S,
49 modify: impl FnOnce(&mut SubscribeContext<S>),
50 ) {
51 self.subscribes
52 .entry(ctx.connection_id)
53 .and_modify(modify)
54 .or_insert_with(|| SubscribeContext {
55 ctx,
56 stream_seq: AtomicU32::new(0),
57 settings: new(),
58 });
59
60 for key in keys {
61 self.mappings.entry(key).or_default().insert(ctx.connection_id);
62 }
63 }
64
65 pub fn unsubscribe(&mut self, connection_id: ConnectionId) {
66 self.subscribes.remove(&connection_id);
67 for pair in self.mappings.values_mut() {
68 pair.remove(&connection_id);
69 }
70 }
71 pub fn unsubscribe_with(
72 &mut self,
73 connection_id: ConnectionId,
74 remove: impl Fn(&mut SubscribeContext<S>) -> (bool, Vec<Key>),
75 ) {
76 let Some((remove1, keys)) = self.subscribes.get_mut(&connection_id).map(remove) else {
77 return;
78 };
79 if remove1 {
80 self.subscribes.remove(&connection_id);
81 }
82 for key in keys {
83 let remove = self
84 .mappings
85 .get_mut(&key)
86 .map(|set| {
87 set.remove(&connection_id);
88 set.is_empty()
89 })
90 .unwrap_or_default();
91 if remove {
92 self.mappings.remove(&key);
93 }
94 }
95 }
96
97 pub fn publish_to(&mut self, toolbox: &ArcToolbox, connection_id: ConnectionId, msg: &impl Serialize) {
98 let mut dead_connection = None;
99
100 let Some(sub) = self.subscribes.get(&connection_id) else {
101 return;
102 };
103
104 let data = serde_json::to_value(msg).unwrap();
105
106 let msg = WsResponseGeneric::Stream(WsStreamResponseGeneric {
107 original_seq: sub.ctx.seq,
108 method: sub.ctx.method,
109 stream_seq: sub.stream_seq.fetch_add(1, Ordering::SeqCst),
110 stream_code: self.stream_code,
111 data: data.clone(),
112 });
113
114 if !toolbox.send(sub.ctx.connection_id, msg) {
115 dead_connection = Some(sub.ctx.connection_id);
116 }
117
118 if let Some(conn_id) = dead_connection {
119 self.unsubscribe(conn_id)
120 }
121 }
122 pub fn publish_to_key<Q>(&mut self, toolbox: &ArcToolbox, key: &Q, msg: &impl Serialize)
123 where
124 Key: Borrow<Q>,
125 Q: Eq + Hash + ?Sized,
126 {
127 let Some(conn_ids) = self.mappings.get(key).cloned() else {
128 return;
129 };
130
131 for conn_id in conn_ids {
132 self.publish_to(toolbox, conn_id, msg);
133 }
134 }
135 pub fn publish_to_keys<Q>(&mut self, toolbox: &ArcToolbox, keys: &[&Q], msg: &impl Serialize)
136 where
137 Key: Borrow<Q>,
138 Q: Eq + Hash + ?Sized,
139 {
140 let mut published = HashSet::new();
141 for key in keys {
142 let conn_ids = self.mappings.get(key).cloned();
143 if let Some(conn_ids) = conn_ids {
144 for conn_id in conn_ids.iter() {
145 if published.insert(*conn_id) {
147 self.publish_to(toolbox, *conn_id, msg);
148 }
149 }
150 }
151 }
152 }
153 pub fn publish_with_filter<M: Serialize>(
154 &mut self,
155 toolbox: &ArcToolbox,
156 filter: impl Fn(&SubscribeContext<S>) -> Option<M>,
157 ) {
158 let mut dead_connections = vec![];
159
160 for sub in self.subscribes.values() {
161 let Some(data) = filter(sub) else {
162 continue;
163 };
164 let data = serde_json::to_value(&data).unwrap();
165 let msg = WsResponseGeneric::Stream(WsStreamResponseGeneric {
166 original_seq: sub.ctx.seq,
167 method: sub.ctx.method,
168 stream_seq: sub.stream_seq.fetch_add(1, Ordering::SeqCst),
169 stream_code: self.stream_code,
170 data,
171 });
172
173 if !toolbox.send(sub.ctx.connection_id, msg) {
174 dead_connections.push(sub.ctx.connection_id);
175 }
176 }
177 for conn_id in dead_connections {
178 self.unsubscribe(conn_id);
179 }
180 }
181 pub fn publish_to_all(&mut self, toolbox: &ArcToolbox, msg: &impl Serialize) {
182 self.publish_with_filter(toolbox, |_| Some(msg))
183 }
184}
185
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use crate::libs::toolbox::{RequestContext, Toolbox};

    /// Subscribing without keys creates exactly one context and no key mappings. After
    /// exercising every publish entry point, `unsubscribe` leaves the manager empty.
    #[tokio::test]
    async fn test_subscribe() {
        let ctx = RequestContext {
            connection_id: 1,
            ..RequestContext::empty()
        };
        let mut manager: SubscriptionManager<(), ()> = SubscriptionManager::new(0);

        manager.subscribe(ctx, (), |_| {});
        assert_eq!(manager.subscribes.len(), 1);
        assert!(manager.mappings.is_empty());

        let toolbox = Arc::new(Toolbox::new());
        manager.publish_to_all(&toolbox, &());
        manager.publish_to_key(&toolbox, &(), &());
        manager.publish_to_keys(&toolbox, &[], &());

        manager.unsubscribe(ctx.connection_id);
        assert!(manager.subscribes.is_empty());
        assert!(manager.mappings.is_empty());
    }
}