Skip to main content

cloudillo_core/
ws_broadcast.rs

1// SPDX-FileCopyrightText: Szilárd Hajba
2// SPDX-License-Identifier: LGPL-3.0-or-later
3
4//! WebSocket User Messaging
5//!
6//! Manages direct user-to-user messaging via WebSocket connections.
7//! Supports multiple connections per user (multiple tabs/devices).
8
9use cloudillo_types::types::TnId;
10use cloudillo_types::utils::random_id;
11use serde_json::Value;
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::{broadcast, RwLock};
15
16/// A message to send to a user
17#[derive(Clone, Debug)]
18pub struct BroadcastMessage {
19	pub id: String,
20	pub cmd: String,
21	pub data: Value,
22	pub sender: String,
23	pub timestamp: u64,
24}
25
26impl BroadcastMessage {
27	/// Create a new message
28	pub fn new(cmd: impl Into<String>, data: Value, sender: impl Into<String>) -> Self {
29		Self {
30			id: random_id().unwrap_or_default(),
31			cmd: cmd.into(),
32			data,
33			sender: sender.into(),
34			timestamp: now_timestamp(),
35		}
36	}
37}
38
39/// A user connection for direct messaging
40#[derive(Debug)]
41pub struct UserConnection {
42	/// User's id_tag
43	pub id_tag: Box<str>,
44	/// Tenant ID
45	pub tn_id: TnId,
46	/// Unique connection ID (UUID) - supports multiple tabs/devices
47	pub connection_id: Box<str>,
48	/// When this connection was established
49	pub connected_at: u64,
50	/// Sender for this connection
51	sender: broadcast::Sender<BroadcastMessage>,
52}
53
/// Outcome of [`BroadcastManager::send_to_user`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeliveryResult {
	/// The message was accepted by this many of the user's connections (always > 0).
	Delivered(usize),
	/// No connection of the user accepted the message (the user is offline).
	UserOffline,
}
62
63/// User registry statistics
64#[derive(Debug, Clone)]
65pub struct UserRegistryStats {
66	/// Number of unique online users
67	pub online_users: usize,
68	/// Total number of connections (may be > users if multiple tabs)
69	pub total_connections: usize,
70	/// Users per tenant
71	pub users_per_tenant: HashMap<TnId, usize>,
72}
73
74/// Type alias for the user registry map: TnId -> id_tag -> Vec<UserConnection>
75type UserRegistryMap = HashMap<TnId, HashMap<Box<str>, Vec<UserConnection>>>;
76
/// Tuning options for [`BroadcastManager`].
#[derive(Debug, Clone)]
pub struct BroadcastConfig {
	/// Capacity of the broadcast channel created for each connection, i.e. how
	/// many messages may be buffered for it.
	pub buffer_size: usize,
}
83
84impl Default for BroadcastConfig {
85	fn default() -> Self {
86		// Enough for typical WebSocket message bursts without excessive memory
87		Self { buffer_size: 128 }
88	}
89}
90
91/// Manages direct user messaging via WebSocket
92pub struct BroadcastManager {
93	/// User registry for direct messaging
94	users: Arc<RwLock<UserRegistryMap>>,
95	config: BroadcastConfig,
96}
97
98impl BroadcastManager {
99	/// Create a new manager with default config
100	pub fn new() -> Self {
101		Self::with_config(BroadcastConfig::default())
102	}
103
104	/// Create with custom config
105	pub fn with_config(config: BroadcastConfig) -> Self {
106		Self { users: Arc::new(RwLock::new(HashMap::new())), config }
107	}
108
109	/// Register a user connection for direct messaging
110	///
111	/// Returns a receiver for messages targeted at this user.
112	/// The connection_id should be a unique identifier (UUID) for this specific
113	/// connection, allowing multiple connections per user (multiple tabs/devices).
114	pub async fn register_user(
115		&self,
116		tn_id: TnId,
117		id_tag: &str,
118		connection_id: &str,
119	) -> broadcast::Receiver<BroadcastMessage> {
120		let (sender, receiver) = broadcast::channel(self.config.buffer_size);
121
122		let connection = UserConnection {
123			id_tag: id_tag.into(),
124			tn_id,
125			connection_id: connection_id.into(),
126			connected_at: now_timestamp(),
127			sender,
128		};
129
130		let mut users = self.users.write().await;
131		users
132			.entry(tn_id)
133			.or_default()
134			.entry(id_tag.into())
135			.or_default()
136			.push(connection);
137
138		tracing::debug!(tn_id = ?tn_id, id_tag = %id_tag, connection_id = %connection_id, "User registered");
139		receiver
140	}
141
142	/// Unregister a user connection
143	///
144	/// Removes the specific connection identified by connection_id.
145	/// Other connections for the same user (other tabs) are preserved.
146	pub async fn unregister_user(&self, tn_id: TnId, id_tag: &str, connection_id: &str) {
147		let mut users = self.users.write().await;
148
149		if let Some(tenant_users) = users.get_mut(&tn_id) {
150			if let Some(connections) = tenant_users.get_mut(id_tag) {
151				connections.retain(|conn| conn.connection_id.as_ref() != connection_id);
152
153				// Clean up empty entries
154				if connections.is_empty() {
155					tenant_users.remove(id_tag);
156				}
157			}
158
159			// Clean up empty tenant entries
160			if tenant_users.is_empty() {
161				users.remove(&tn_id);
162			}
163		}
164
165		tracing::debug!(tn_id = ?tn_id, id_tag = %id_tag, connection_id = %connection_id, "User unregistered");
166	}
167
168	/// Send a message to a specific user
169	///
170	/// Delivers the message to all connections for the user (multiple tabs/devices).
171	/// Returns `DeliveryResult::Delivered(n)` with the number of connections that
172	/// received the message, or `DeliveryResult::UserOffline` if the user has no
173	/// active connections.
174	pub async fn send_to_user(
175		&self,
176		tn_id: TnId,
177		id_tag: &str,
178		msg: BroadcastMessage,
179	) -> DeliveryResult {
180		let users = self.users.read().await;
181
182		if let Some(tenant_users) = users.get(&tn_id) {
183			if let Some(connections) = tenant_users.get(id_tag) {
184				let mut delivered = 0;
185				for conn in connections {
186					if conn.sender.send(msg.clone()).is_ok() {
187						delivered += 1;
188					}
189				}
190				if delivered > 0 {
191					return DeliveryResult::Delivered(delivered);
192				}
193			}
194		}
195
196		DeliveryResult::UserOffline
197	}
198
199	/// Send a message to all users in a tenant
200	///
201	/// Broadcasts the message to all connections for all users in the tenant.
202	/// Returns the total number of connections that received the message.
203	pub async fn send_to_tenant(&self, tn_id: TnId, msg: BroadcastMessage) -> usize {
204		let users = self.users.read().await;
205
206		let mut delivered = 0;
207		if let Some(tenant_users) = users.get(&tn_id) {
208			for connections in tenant_users.values() {
209				for conn in connections {
210					if conn.sender.send(msg.clone()).is_ok() {
211						delivered += 1;
212					}
213				}
214			}
215		}
216		delivered
217	}
218
219	/// Check if a user is currently online (has at least one connection)
220	pub async fn is_user_online(&self, tn_id: TnId, id_tag: &str) -> bool {
221		let users = self.users.read().await;
222
223		users
224			.get(&tn_id)
225			.and_then(|tenant_users| tenant_users.get(id_tag))
226			.is_some_and(|connections| !connections.is_empty())
227	}
228
229	/// Get list of all online users for a tenant
230	pub async fn online_users(&self, tn_id: TnId) -> Vec<Box<str>> {
231		let users = self.users.read().await;
232
233		users
234			.get(&tn_id)
235			.map(|tenant_users| tenant_users.keys().cloned().collect())
236			.unwrap_or_default()
237	}
238
239	/// Get user registry statistics
240	pub async fn user_stats(&self) -> UserRegistryStats {
241		let users = self.users.read().await;
242
243		let mut online_users = 0;
244		let mut total_connections = 0;
245		let mut users_per_tenant = HashMap::new();
246
247		for (tn_id, tenant_users) in users.iter() {
248			let tenant_user_count = tenant_users.len();
249			online_users += tenant_user_count;
250			users_per_tenant.insert(*tn_id, tenant_user_count);
251
252			for connections in tenant_users.values() {
253				total_connections += connections.len();
254			}
255		}
256
257		UserRegistryStats { online_users, total_connections, users_per_tenant }
258	}
259
260	/// Cleanup disconnected users (users with no active receivers)
261	pub async fn cleanup_users(&self) {
262		let mut users = self.users.write().await;
263
264		for tenant_users in users.values_mut() {
265			for connections in tenant_users.values_mut() {
266				connections.retain(|conn| conn.sender.receiver_count() > 0);
267			}
268			tenant_users.retain(|_, connections| !connections.is_empty());
269		}
270		users.retain(|_, tenant_users| !tenant_users.is_empty());
271	}
272}
273
274impl Default for BroadcastManager {
275	fn default() -> Self {
276		Self::new()
277	}
278}
279
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now_timestamp() -> u64 {
	let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
	match since_epoch {
		Ok(elapsed) => elapsed.as_secs(),
		Err(_) => 0,
	}
}
287
#[cfg(test)]
mod tests {
	use super::*;

	#[tokio::test]
	async fn test_register_user() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let _rx = manager.register_user(tn_id, "alice", "conn-1").await;

		assert!(manager.is_user_online(tn_id, "alice").await);
		assert!(!manager.is_user_online(tn_id, "bob").await);

		let stats = manager.user_stats().await;
		assert_eq!(stats.online_users, 1);
		assert_eq!(stats.total_connections, 1);
	}

	#[tokio::test]
	async fn test_multiple_connections_per_user() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let _rx1 = manager.register_user(tn_id, "alice", "conn-1").await;
		let _rx2 = manager.register_user(tn_id, "alice", "conn-2").await;

		let stats = manager.user_stats().await;
		assert_eq!(stats.online_users, 1);
		assert_eq!(stats.total_connections, 2);
	}

	#[tokio::test]
	async fn test_send_to_user() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let mut rx = manager.register_user(tn_id, "alice", "conn-1").await;

		let msg = BroadcastMessage::new("ACTION", serde_json::json!({ "type": "MSG" }), "system");
		let result = manager.send_to_user(tn_id, "alice", msg).await;

		assert_eq!(result, DeliveryResult::Delivered(1));

		let received = rx.recv().await.unwrap();
		assert_eq!(received.cmd, "ACTION");
	}

	#[tokio::test]
	async fn test_send_to_user_reaches_every_connection() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let mut rx1 = manager.register_user(tn_id, "alice", "conn-1").await;
		let mut rx2 = manager.register_user(tn_id, "alice", "conn-2").await;

		let msg = BroadcastMessage::new("ACTION", serde_json::json!({}), "system");
		let result = manager.send_to_user(tn_id, "alice", msg).await;
		assert_eq!(result, DeliveryResult::Delivered(2));

		assert_eq!(rx1.recv().await.unwrap().cmd, "ACTION");
		assert_eq!(rx2.recv().await.unwrap().cmd, "ACTION");
	}

	#[tokio::test]
	async fn test_send_to_offline_user() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let msg = BroadcastMessage::new("ACTION", serde_json::json!({ "type": "MSG" }), "system");
		let result = manager.send_to_user(tn_id, "bob", msg).await;

		assert_eq!(result, DeliveryResult::UserOffline);
	}

	#[tokio::test]
	async fn test_send_to_user_with_dropped_receiver() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		// The connection stays registered, but nobody is listening any more.
		drop(manager.register_user(tn_id, "alice", "conn-1").await);

		let msg = BroadcastMessage::new("ACTION", serde_json::json!({}), "system");
		let result = manager.send_to_user(tn_id, "alice", msg).await;
		assert_eq!(result, DeliveryResult::UserOffline);
	}

	#[tokio::test]
	async fn test_unregister_user() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let _rx = manager.register_user(tn_id, "alice", "conn-1").await;
		assert!(manager.is_user_online(tn_id, "alice").await);

		manager.unregister_user(tn_id, "alice", "conn-1").await;
		assert!(!manager.is_user_online(tn_id, "alice").await);
	}

	#[tokio::test]
	async fn test_unregister_keeps_other_connections() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let _rx1 = manager.register_user(tn_id, "alice", "conn-1").await;
		let _rx2 = manager.register_user(tn_id, "alice", "conn-2").await;

		manager.unregister_user(tn_id, "alice", "conn-1").await;
		assert!(manager.is_user_online(tn_id, "alice").await);

		let stats = manager.user_stats().await;
		assert_eq!(stats.online_users, 1);
		assert_eq!(stats.total_connections, 1);
	}

	#[tokio::test]
	async fn test_multi_tenant_isolation() {
		let manager = BroadcastManager::new();
		let tn1 = TnId(1);
		let tn2 = TnId(2);

		let _rx1 = manager.register_user(tn1, "alice", "conn-1").await;
		let _rx2 = manager.register_user(tn2, "alice", "conn-2").await;

		assert!(manager.is_user_online(tn1, "alice").await);
		assert!(manager.is_user_online(tn2, "alice").await);

		let msg = BroadcastMessage::new("test", serde_json::json!({}), "system");
		let result = manager.send_to_user(tn1, "alice", msg).await;
		assert_eq!(result, DeliveryResult::Delivered(1));

		let stats = manager.user_stats().await;
		assert_eq!(stats.online_users, 2);
	}

	#[tokio::test]
	async fn test_send_to_tenant() {
		let manager = BroadcastManager::new();
		let tn1 = TnId(1);
		let tn2 = TnId(2);

		let mut alice = manager.register_user(tn1, "alice", "conn-1").await;
		let mut bob = manager.register_user(tn1, "bob", "conn-2").await;
		let mut carol = manager.register_user(tn2, "carol", "conn-3").await;

		let msg = BroadcastMessage::new("NOTIFY", serde_json::json!({}), "system");
		assert_eq!(manager.send_to_tenant(tn1, msg).await, 2);

		assert_eq!(alice.recv().await.unwrap().cmd, "NOTIFY");
		assert_eq!(bob.recv().await.unwrap().cmd, "NOTIFY");
		// Users of other tenants must not see the broadcast.
		assert!(carol.try_recv().is_err());

		let msg = BroadcastMessage::new("NOTIFY", serde_json::json!({}), "system");
		assert_eq!(manager.send_to_tenant(TnId(99), msg).await, 0);
	}

	#[tokio::test]
	async fn test_online_users() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let _rx1 = manager.register_user(tn_id, "bob", "conn-1").await;
		let _rx2 = manager.register_user(tn_id, "alice", "conn-2").await;
		let _rx3 = manager.register_user(tn_id, "alice", "conn-3").await;

		let mut names = manager.online_users(tn_id).await;
		names.sort();
		assert_eq!(names, vec![Box::<str>::from("alice"), Box::<str>::from("bob")]);

		assert!(manager.online_users(TnId(99)).await.is_empty());
	}

	#[tokio::test]
	async fn test_cleanup_users_removes_dropped_receivers() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let rx = manager.register_user(tn_id, "alice", "conn-1").await;
		let _keep = manager.register_user(tn_id, "bob", "conn-2").await;
		drop(rx);

		// Dropped receivers still count as online until cleanup runs.
		assert!(manager.is_user_online(tn_id, "alice").await);

		manager.cleanup_users().await;
		assert!(!manager.is_user_online(tn_id, "alice").await);
		assert!(manager.is_user_online(tn_id, "bob").await);

		let stats = manager.user_stats().await;
		assert_eq!(stats.online_users, 1);
		assert_eq!(stats.total_connections, 1);
	}
}
379
380// vim: ts=4