axum_cometd/
context.rs

1mod build_router;
2mod builder;
3mod subscription_task;
4
5pub use {build_router::*, builder::*};
6
7use crate::{
8    messages::SubscriptionMessage,
9    types::{ChannelId, ClientId, ClientReceiver, ClientSender, CookieId},
10    utils::{ChannelNameValidator, WildNamesCache},
11    CometdCustomDataSender, CometdEventReceiver, Event, SendError,
12};
13use ahash::{HashMap, HashSet, HashSetExt as _};
14use async_broadcast::{InactiveReceiver, Sender};
15use core::{fmt::Debug, ops::Deref};
16use serde::Serialize;
17use serde_json::json;
18use std::{collections::hash_map::Entry, sync::Arc};
19use tokio::sync::{mpsc, RwLock};
20
21/// Context for sending messages to channels.
22#[derive(Debug)]
23pub struct LongPollingServiceContext<AdditionalData, CustomData> {
24    pub(crate) tx: Sender<Arc<Event<AdditionalData, CustomData>>>,
25    pub(crate) inactive_rx: InactiveReceiver<Arc<Event<AdditionalData, CustomData>>>,
26
27    pub(crate) wildnames_cache: WildNamesCache,
28    pub(crate) channel_name_validator: ChannelNameValidator,
29    pub(crate) consts: LongPollingServiceContextConsts,
30    pub(crate) channels_data: RwLock<HashMap<ChannelId, Channel>>,
31    client_id_senders: Arc<RwLock<HashMap<ClientId, ClientSender>>>,
32}
33
34#[derive(Debug)]
35pub(crate) struct Channel {
36    client_ids: HashSet<ClientId>,
37    tx: mpsc::Sender<SubscriptionMessage>,
38}
39
40impl Channel {
41    #[inline(always)]
42    const fn client_ids(&self) -> &HashSet<ClientId> {
43        &self.client_ids
44    }
45
46    #[inline(always)]
47    pub(crate) const fn tx(&self) -> &mpsc::Sender<SubscriptionMessage> {
48        &self.tx
49    }
50}
51
52impl<AdditionalData, CustomData> LongPollingServiceContext<AdditionalData, CustomData> {
53    /// Get new events receiver.
54    ///
55    /// # Example
56    /// ```rust,no_run
57    /// # async {
58    /// # let context = axum_cometd::LongPollingServiceContextBuilder::new().build::<(), ()>();
59    ///     let mut rx = context.rx();
60    ///     
61    ///     while let Some(event) = rx.recv().await {
62    ///         println!("Got event: `{event:?}`");
63    ///     }
64    /// # };
65    /// ```
66    pub fn rx(&self) -> CometdEventReceiver<AdditionalData, CustomData> {
67        CometdEventReceiver(self.inactive_rx.activate_cloned())
68    }
69
70    /// Get new events sender.
71    ///
72    /// # Example
73    /// ```rust,no_run
74    /// # use std::sync::Arc;
75    /// # use axum_cometd::Event;
76    ///  async {
77    /// # let context = axum_cometd::LongPollingServiceContextBuilder::new().build::<(), &'static str>();
78    ///     let tx = context.tx();
79    ///     
80    ///     tx.send("hello").await;
81    /// # };
82    /// ```
83    pub fn tx(&self) -> CometdCustomDataSender<AdditionalData, CustomData> {
84        CometdCustomDataSender(self.tx.clone())
85    }
86
87    /// Send message to channel.
88    ///
89    /// # Example
90    /// ```rust,no_run
91    /// # use core::time::Duration;
92    /// #[derive(Debug, Clone, serde::Serialize)]
93    ///     struct Data<'a> {
94    ///         msg: std::borrow::Cow<'a, str>,
95    ///         r#bool: bool,
96    ///         num: u64,
97    ///     }
98    ///
99    /// # async {
100    ///     let context = axum_cometd::LongPollingServiceContextBuilder::new()
101    ///         .timeout(Duration::from_secs(1))
102    ///         .max_interval(Duration::from_secs(2))
103    ///         .client_channel_capacity(10_000)
104    ///         .subscription_channel_capacity(20_000)
105    ///         .build::<(), ()>();
106    ///
107    ///     loop {
108    ///         context
109    ///             .send(
110    ///                 "/topic",
111    ///                 Data {
112    ///                     msg: "Hello World!!!".into(),
113    ///                     r#bool: true,
114    ///                     num: u64::MAX,
115    ///                 },
116    ///             )
117    ///             .await?;
118    ///         tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
119    ///     }
120    /// # Ok::<(), axum_cometd::SendError>(())
121    /// # };
122    /// ```
123    #[inline]
124    pub async fn send(
125        &self,
126        channel: &str,
127        message: impl Debug + Serialize,
128    ) -> Result<(), SendError> {
129        self.channel_name_validator
130            .validate_send_channel_name(channel)
131            .then_some(())
132            .ok_or(SendError::InvalidChannel)?;
133
134        let subscription_message = SubscriptionMessage {
135            channel: channel.to_owned(),
136            msg: json!(message),
137        };
138        let wildnames = self.wildnames_cache.fetch_wildnames(channel);
139        let read_guard = self.channels_data.read().await;
140        for channel in core::iter::once(channel).chain(wildnames.iter().map(String::deref)) {
141            if let Some(tx) = read_guard.get(channel).map(Channel::tx) {
142                tx.send(subscription_message.clone()).await?;
143            } else {
144                tracing::warn!(
145                    channel = channel,
146                    "No `{channel}` channel was found for message: `{message:?}`."
147                );
148            }
149        }
150
151        Ok(())
152    }
153
154    /// Send message direct to client.
155    #[inline]
156    pub async fn send_to_client(
157        &self,
158        channel: &str,
159        client_id: &ClientId,
160        msg: impl Debug + Serialize,
161    ) -> Result<(), SendError> {
162        self.channel_name_validator
163            .validate_send_channel_name(channel)
164            .then_some(())
165            .ok_or(SendError::InvalidChannel)?;
166
167        if let Some(tx) = self.client_id_senders.read().await.get(client_id) {
168            tx.send(SubscriptionMessage {
169                channel: channel.to_owned(),
170                msg: json!(msg),
171            })
172            .await?;
173
174            Ok(())
175        } else {
176            tracing::warn!(
177                client_id = %client_id,
178                "No `{client_id}` client was found for message: `{msg:?}`."
179            );
180
181            Err(SendError::ClientWasntFound(*client_id))
182        }
183    }
184
185    pub(crate) async fn register(self: &Arc<Self>, cookie_id: CookieId) -> Option<ClientId>
186    where
187        AdditionalData: Send + Sync + 'static,
188        CustomData: Send + Sync + 'static,
189    {
190        let client_id = {
191            let mut client_id_channels_write_guard = self.client_id_senders.write().await;
192
193            let client_id = ClientId::gen();
194            let (tx, rx) = mpsc::channel(self.consts.client_channel_capacity);
195
196            match client_id_channels_write_guard.entry(client_id) {
197                Entry::Occupied(_) => return None,
198                Entry::Vacant(v) => {
199                    v.insert(ClientSender::create(
200                        Arc::clone(self),
201                        cookie_id,
202                        client_id,
203                        self.consts.max_interval,
204                        tx,
205                        rx,
206                    ));
207                }
208            }
209
210            Some(client_id)
211        }?;
212
213        tracing::info!(
214            client_id = %client_id,
215            "New client was registered with clientId `{client_id}`."
216        );
217
218        Some(client_id)
219    }
220
221    pub(crate) async fn subscribe(self: &Arc<Self>, client_id: ClientId, channels: &[String])
222    where
223        AdditionalData: Send + Sync + 'static,
224        CustomData: Send + Sync + 'static,
225    {
226        let mut channels_data_write_guard = self.channels_data.write().await;
227        for channel in channels {
228            match channels_data_write_guard.entry(channel.clone()) {
229                Entry::Occupied(o) => o.into_mut(),
230                Entry::Vacant(v) => {
231                    let (tx, rx) = mpsc::channel(self.consts.subscription_channel_capacity);
232
233                    subscription_task::spawn(channel.clone(), rx, Arc::clone(self));
234                    tracing::info!(
235                        channel = channel,
236                        "New subscription ({channel}) channel was registered."
237                    );
238
239                    v.insert(Channel {
240                        client_ids: Default::default(),
241                        tx,
242                    })
243                }
244            }
245            .client_ids
246            .insert(client_id);
247        }
248
249        tracing::info!(
250            client_id = %client_id,
251            channels = debug(channels),
252            "Client with clientId `{client_id}` subscribe on `{channels:?}` channels."
253        );
254    }
255
256    // TODO: Spawn task and send unsubscribe command through channel?
257    /// Remove client.
258    #[inline]
259    pub async fn unsubscribe(self: &Arc<Self>, client_id: ClientId) {
260        tokio::join!(
261            self.remove_client_id_from_subscriptions(&client_id),
262            self.remove_client_tx(&client_id),
263        );
264
265        let _ = self
266            .tx
267            .broadcast(Arc::new(Event::SessionRemoved { client_id }))
268            .await;
269    }
270
271    #[inline]
272    async fn remove_client_id_from_subscriptions(&self, client_id: &ClientId) {
273        // TODO: drain_filter: https://github.com/rust-lang/rust/issues/59618
274        // TODO: Replace on LinkedList?
275        let mut removed_channels = HashSet::new();
276
277        self.channels_data.write().await.retain(
278            |channel,
279             &mut Channel {
280                 ref mut client_ids,
281                 tx: _,
282             }| {
283                if client_ids.remove(client_id) {
284                    tracing::info!(
285                        client_id = %client_id,
286                        channel = channel,
287                        "Client `{client_id}` was unsubscribed from channel `{channel}."
288                    );
289                }
290
291                if client_ids.is_empty() {
292                    tracing::info!(
293                        channel = channel,
294                        "Channel `{channel}` have no active subscriber. Eliminate channel."
295                    );
296                    removed_channels.insert(channel.clone());
297                    false
298                } else {
299                    true
300                }
301            },
302        );
303
304        self.wildnames_cache.remove_wildnames(removed_channels);
305    }
306
307    #[inline]
308    async fn remove_client_tx(&self, client_id: &ClientId) {
309        if self
310            .client_id_senders
311            .write()
312            .await
313            .remove(client_id)
314            .is_some()
315        {
316            tracing::info!(
317                client_id = %client_id,
318                "Client `{client_id}` was unsubscribed."
319            );
320        } else {
321            tracing::warn!(
322                client_id = %client_id,
323                "Can't find client `{client_id}`. Can't unsubscribed."
324            );
325        }
326    }
327
328    #[inline]
329    pub(crate) async fn check_client(
330        &self,
331        cookie_id: CookieId,
332        client_id: &ClientId,
333    ) -> Option<()> {
334        self.client_id_senders
335            .read()
336            .await
337            .get(client_id)
338            .map(ClientSender::cookie_id)
339            .eq(&Some(cookie_id))
340            .then_some(())
341    }
342
343    #[inline]
344    pub(crate) async fn get_client_receiver(&self, client_id: &ClientId) -> Option<ClientReceiver> {
345        self.client_id_senders
346            .read()
347            .await
348            .get(client_id)
349            .map(ClientSender::subscribe)
350    }
351}