1mod build_router;
2mod builder;
3mod subscription_task;
4
5pub use {build_router::*, builder::*};
6
7use crate::{
8 messages::SubscriptionMessage,
9 types::{ChannelId, ClientId, ClientReceiver, ClientSender, CookieId},
10 utils::{ChannelNameValidator, WildNamesCache},
11 CometdCustomDataSender, CometdEventReceiver, Event, SendError,
12};
13use ahash::{HashMap, HashSet, HashSetExt as _};
14use async_broadcast::{InactiveReceiver, Sender};
15use core::{fmt::Debug, ops::Deref};
16use serde::Serialize;
17use serde_json::json;
18use std::{collections::hash_map::Entry, sync::Arc};
19use tokio::sync::{mpsc, RwLock};
20
21#[derive(Debug)]
23pub struct LongPollingServiceContext<AdditionalData, CustomData> {
24 pub(crate) tx: Sender<Arc<Event<AdditionalData, CustomData>>>,
25 pub(crate) inactive_rx: InactiveReceiver<Arc<Event<AdditionalData, CustomData>>>,
26
27 pub(crate) wildnames_cache: WildNamesCache,
28 pub(crate) channel_name_validator: ChannelNameValidator,
29 pub(crate) consts: LongPollingServiceContextConsts,
30 pub(crate) channels_data: RwLock<HashMap<ChannelId, Channel>>,
31 client_id_senders: Arc<RwLock<HashMap<ClientId, ClientSender>>>,
32}
33
34#[derive(Debug)]
35pub(crate) struct Channel {
36 client_ids: HashSet<ClientId>,
37 tx: mpsc::Sender<SubscriptionMessage>,
38}
39
40impl Channel {
41 #[inline(always)]
42 const fn client_ids(&self) -> &HashSet<ClientId> {
43 &self.client_ids
44 }
45
46 #[inline(always)]
47 pub(crate) const fn tx(&self) -> &mpsc::Sender<SubscriptionMessage> {
48 &self.tx
49 }
50}
51
52impl<AdditionalData, CustomData> LongPollingServiceContext<AdditionalData, CustomData> {
53 pub fn rx(&self) -> CometdEventReceiver<AdditionalData, CustomData> {
67 CometdEventReceiver(self.inactive_rx.activate_cloned())
68 }
69
70 pub fn tx(&self) -> CometdCustomDataSender<AdditionalData, CustomData> {
84 CometdCustomDataSender(self.tx.clone())
85 }
86
87 #[inline]
124 pub async fn send(
125 &self,
126 channel: &str,
127 message: impl Debug + Serialize,
128 ) -> Result<(), SendError> {
129 self.channel_name_validator
130 .validate_send_channel_name(channel)
131 .then_some(())
132 .ok_or(SendError::InvalidChannel)?;
133
134 let subscription_message = SubscriptionMessage {
135 channel: channel.to_owned(),
136 msg: json!(message),
137 };
138 let wildnames = self.wildnames_cache.fetch_wildnames(channel);
139 let read_guard = self.channels_data.read().await;
140 for channel in core::iter::once(channel).chain(wildnames.iter().map(String::deref)) {
141 if let Some(tx) = read_guard.get(channel).map(Channel::tx) {
142 tx.send(subscription_message.clone()).await?;
143 } else {
144 tracing::warn!(
145 channel = channel,
146 "No `{channel}` channel was found for message: `{message:?}`."
147 );
148 }
149 }
150
151 Ok(())
152 }
153
154 #[inline]
156 pub async fn send_to_client(
157 &self,
158 channel: &str,
159 client_id: &ClientId,
160 msg: impl Debug + Serialize,
161 ) -> Result<(), SendError> {
162 self.channel_name_validator
163 .validate_send_channel_name(channel)
164 .then_some(())
165 .ok_or(SendError::InvalidChannel)?;
166
167 if let Some(tx) = self.client_id_senders.read().await.get(client_id) {
168 tx.send(SubscriptionMessage {
169 channel: channel.to_owned(),
170 msg: json!(msg),
171 })
172 .await?;
173
174 Ok(())
175 } else {
176 tracing::warn!(
177 client_id = %client_id,
178 "No `{client_id}` client was found for message: `{msg:?}`."
179 );
180
181 Err(SendError::ClientWasntFound(*client_id))
182 }
183 }
184
185 pub(crate) async fn register(self: &Arc<Self>, cookie_id: CookieId) -> Option<ClientId>
186 where
187 AdditionalData: Send + Sync + 'static,
188 CustomData: Send + Sync + 'static,
189 {
190 let client_id = {
191 let mut client_id_channels_write_guard = self.client_id_senders.write().await;
192
193 let client_id = ClientId::gen();
194 let (tx, rx) = mpsc::channel(self.consts.client_channel_capacity);
195
196 match client_id_channels_write_guard.entry(client_id) {
197 Entry::Occupied(_) => return None,
198 Entry::Vacant(v) => {
199 v.insert(ClientSender::create(
200 Arc::clone(self),
201 cookie_id,
202 client_id,
203 self.consts.max_interval,
204 tx,
205 rx,
206 ));
207 }
208 }
209
210 Some(client_id)
211 }?;
212
213 tracing::info!(
214 client_id = %client_id,
215 "New client was registered with clientId `{client_id}`."
216 );
217
218 Some(client_id)
219 }
220
221 pub(crate) async fn subscribe(self: &Arc<Self>, client_id: ClientId, channels: &[String])
222 where
223 AdditionalData: Send + Sync + 'static,
224 CustomData: Send + Sync + 'static,
225 {
226 let mut channels_data_write_guard = self.channels_data.write().await;
227 for channel in channels {
228 match channels_data_write_guard.entry(channel.clone()) {
229 Entry::Occupied(o) => o.into_mut(),
230 Entry::Vacant(v) => {
231 let (tx, rx) = mpsc::channel(self.consts.subscription_channel_capacity);
232
233 subscription_task::spawn(channel.clone(), rx, Arc::clone(self));
234 tracing::info!(
235 channel = channel,
236 "New subscription ({channel}) channel was registered."
237 );
238
239 v.insert(Channel {
240 client_ids: Default::default(),
241 tx,
242 })
243 }
244 }
245 .client_ids
246 .insert(client_id);
247 }
248
249 tracing::info!(
250 client_id = %client_id,
251 channels = debug(channels),
252 "Client with clientId `{client_id}` subscribe on `{channels:?}` channels."
253 );
254 }
255
256 #[inline]
259 pub async fn unsubscribe(self: &Arc<Self>, client_id: ClientId) {
260 tokio::join!(
261 self.remove_client_id_from_subscriptions(&client_id),
262 self.remove_client_tx(&client_id),
263 );
264
265 let _ = self
266 .tx
267 .broadcast(Arc::new(Event::SessionRemoved { client_id }))
268 .await;
269 }
270
271 #[inline]
272 async fn remove_client_id_from_subscriptions(&self, client_id: &ClientId) {
273 let mut removed_channels = HashSet::new();
276
277 self.channels_data.write().await.retain(
278 |channel,
279 &mut Channel {
280 ref mut client_ids,
281 tx: _,
282 }| {
283 if client_ids.remove(client_id) {
284 tracing::info!(
285 client_id = %client_id,
286 channel = channel,
287 "Client `{client_id}` was unsubscribed from channel `{channel}."
288 );
289 }
290
291 if client_ids.is_empty() {
292 tracing::info!(
293 channel = channel,
294 "Channel `{channel}` have no active subscriber. Eliminate channel."
295 );
296 removed_channels.insert(channel.clone());
297 false
298 } else {
299 true
300 }
301 },
302 );
303
304 self.wildnames_cache.remove_wildnames(removed_channels);
305 }
306
307 #[inline]
308 async fn remove_client_tx(&self, client_id: &ClientId) {
309 if self
310 .client_id_senders
311 .write()
312 .await
313 .remove(client_id)
314 .is_some()
315 {
316 tracing::info!(
317 client_id = %client_id,
318 "Client `{client_id}` was unsubscribed."
319 );
320 } else {
321 tracing::warn!(
322 client_id = %client_id,
323 "Can't find client `{client_id}`. Can't unsubscribed."
324 );
325 }
326 }
327
328 #[inline]
329 pub(crate) async fn check_client(
330 &self,
331 cookie_id: CookieId,
332 client_id: &ClientId,
333 ) -> Option<()> {
334 self.client_id_senders
335 .read()
336 .await
337 .get(client_id)
338 .map(ClientSender::cookie_id)
339 .eq(&Some(cookie_id))
340 .then_some(())
341 }
342
343 #[inline]
344 pub(crate) async fn get_client_receiver(&self, client_id: &ClientId) -> Option<ClientReceiver> {
345 self.client_id_senders
346 .read()
347 .await
348 .get(client_id)
349 .map(ClientSender::subscribe)
350 }
351}