fhtmx_actix/
sse.rs

1//! # Server sent events
2
3use actix_web::{Responder, web};
4use actix_web_lab::sse::{Data, Event};
5use dashmap::DashMap;
6use futures::{StreamExt, stream};
7use serde::{Deserialize, Serialize, de::DeserializeOwned};
8use std::{marker::PhantomData, sync::Arc, time::Duration};
9use tokio::sync::mpsc;
10use uuid::Uuid;
11
12// TODO: clean removed sessions
13
/// Sets up the server-sent events (SSE) state and route.
///
/// `T` is the type of the optional per-session data that the SSE route
/// deserializes from the request's query string. The struct itself stores no
/// value of type `T`; it only carries the type so that [`SseSetup::state_data`]
/// and [`SseSetup::setup_route`] agree on it.
///
/// # Example
///
/// ```rust,ignore
/// use actix_web::App;
/// use fhtmx_actix::sse::SseSetup;
///
/// let sse_setup = SseSetup::new();
/// let sse_data = sse_setup.state_data();
/// App::new()
///     .configure(|cfg| sse_setup.setup_route("/sse", cfg))
///     .app_data(sse_data);
/// ```
pub struct SseSetup<T> {
    session_data: PhantomData<T>,
}

// `Clone`/`Copy` are implemented by hand instead of derived: the derives would
// add `T: Clone` / `T: Copy` bounds even though no `T` is ever stored, which
// would make the setup non-copyable for most session data types.
impl<T> Clone for SseSetup<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for SseSetup<T> {}
32
/// Marker type used as the session data type when sessions carry no data.
///
/// This is the type parameter of the [`SseSetup`] returned by `SseSetup::new()`
/// and `SseSetup::default()`.
#[derive(Debug, Clone, Copy, PartialEq, Deserialize)]
pub struct FhtmxUiNoSessionData;
36
37impl SseSetup<()> {
38    pub fn new_with_data<T>() -> SseSetup<T> {
39        SseSetup {
40            session_data: PhantomData,
41        }
42    }
43}
44
45impl Default for SseSetup<FhtmxUiNoSessionData> {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl SseSetup<FhtmxUiNoSessionData> {
52    pub fn new() -> Self {
53        SseSetup {
54            session_data: PhantomData,
55        }
56    }
57}
58
59impl<T> SseSetup<T>
60where
61    T: DeserializeOwned + 'static,
62{
63    /// Gets a `SseState` instance for you to add it to your app
64    #[must_use]
65    pub fn state_data(&self) -> web::Data<SseState<T>> {
66        web::Data::new(SseState::default())
67    }
68
69    /// Setups the sse route
70    pub fn setup_route(&self, path: &str, cfg: &mut web::ServiceConfig) {
71        cfg.route(path, web::get().to(sse_handler::<T>));
72    }
73}
74
/// SSE state: the registry of sessions that events can be sent to.
///
/// `T` is the type of the optional data attached to each session.
pub struct SseState<T> {
    /// Registered sessions, keyed by session id.
    pub sessions: Arc<DashMap<Uuid, SseSession<T>>>,
}
79
80impl<T> Default for SseState<T> {
81    fn default() -> Self {
82        Self {
83            sessions: Arc::new(DashMap::new()),
84        }
85    }
86}
87
88impl<T> SseState<T> {
89    pub fn add_session(
90        &self,
91        id: Uuid,
92        data: Option<T>,
93        sender: mpsc::Sender<Event>,
94    ) -> Option<SseSession<T>> {
95        let session = SseSession { data, sender };
96        self.sessions.insert(id, session)
97    }
98
99    pub fn remove_session(&self, id: Uuid) -> Option<(Uuid, SseSession<T>)> {
100        self.sessions.remove(&id)
101    }
102
103    /// Sends a message to session id
104    pub async fn send_message(&self, id: Uuid, data: Data) -> Option<()> {
105        let sender = self.sessions.get(&id)?.sender.clone();
106        if sender.send(Event::Data(data)).await.is_err() {
107            // Channel is closed so we remove the session
108            self.remove_session(id);
109        }
110        Some(())
111    }
112
113    /// Broadcast a message to all sessions and returns the number of sent messages
114    pub async fn broadcast(&self, data: Data) -> usize {
115        let senders = self
116            .sessions
117            .iter()
118            .map(|o| o.value().sender.clone())
119            .collect::<Vec<_>>();
120        sse_broadcast(senders, data).await
121    }
122
123    /// Broadcast a message to all sessions but one id and returns the number of sent messages
124    pub async fn broadcast_all_but(&self, id: Uuid, data: Data) -> usize {
125        let senders = self
126            .sessions
127            .iter()
128            .filter_map(|o| {
129                if *o.key() == id {
130                    None
131                } else {
132                    Some(o.value().sender.clone())
133                }
134            })
135            .collect::<Vec<_>>();
136        sse_broadcast(senders, data).await
137    }
138}
139
140pub async fn sse_broadcast(senders: Vec<mpsc::Sender<Event>>, data: Data) -> usize {
141    stream::iter(senders)
142        .filter_map(|o| {
143            let data = data.clone();
144            async move { o.send(Event::Data(data)).await.ok() }
145        })
146        .count()
147        .await
148}
149
/// A single registered SSE session.
pub struct SseSession<T> {
    /// Optional data attached to the session when it was registered.
    pub data: Option<T>,
    /// Sending half of the channel used to push events to this session.
    pub sender: mpsc::Sender<Event>,
}
154
/// Query-string parameters read by [`sse_handler`] to register a session.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SseQueryParams<T> {
    /// Identifier for the session
    pub id: Uuid,
    /// Optional data to attach to the session.
    pub data: Option<T>,
}
161
162/// Route to handle web sockets
163#[tracing::instrument(skip_all)]
164pub async fn sse_handler<T>(
165    web::Query(query): web::Query<SseQueryParams<T>>,
166    state: web::Data<SseState<T>>,
167) -> impl Responder {
168    let (tx, rx) = mpsc::channel(8);
169    state.add_session(query.id, query.data, tx);
170    actix_web_lab::sse::Sse::from_infallible_receiver(rx).with_keep_alive(Duration::from_secs(3))
171}