Skip to main content

cloudillo_core/
ws_broadcast.rs

1//! WebSocket User Messaging
2//!
//! Manages message delivery to WebSocket-connected users, either to a single user or to every user in a tenant.
4//! Supports multiple connections per user (multiple tabs/devices).
5
6use cloudillo_types::types::TnId;
7use cloudillo_types::utils::random_id;
8use serde_json::Value;
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::{broadcast, RwLock};
12
13/// A message to send to a user
14#[derive(Clone, Debug)]
15pub struct BroadcastMessage {
16	pub id: String,
17	pub cmd: String,
18	pub data: Value,
19	pub sender: String,
20	pub timestamp: u64,
21}
22
23impl BroadcastMessage {
24	/// Create a new message
25	pub fn new(cmd: impl Into<String>, data: Value, sender: impl Into<String>) -> Self {
26		Self {
27			id: random_id().unwrap_or_default(),
28			cmd: cmd.into(),
29			data,
30			sender: sender.into(),
31			timestamp: now_timestamp(),
32		}
33	}
34}
35
36/// A user connection for direct messaging
37#[derive(Debug)]
38pub struct UserConnection {
39	/// User's id_tag
40	pub id_tag: Box<str>,
41	/// Tenant ID
42	pub tn_id: TnId,
43	/// Unique connection ID (UUID) - supports multiple tabs/devices
44	pub connection_id: Box<str>,
45	/// When this connection was established
46	pub connected_at: u64,
47	/// Sender for this connection
48	sender: broadcast::Sender<BroadcastMessage>,
49}
50
/// Outcome of delivering a message to a single user.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeliveryResult {
	/// The message reached this many connections (as produced by
	/// `BroadcastManager::send_to_user`, this is always at least one)
	Delivered(usize),
	/// No connection of the user accepted the message (the user is offline)
	UserOffline,
}
59
60/// User registry statistics
61#[derive(Debug, Clone)]
62pub struct UserRegistryStats {
63	/// Number of unique online users
64	pub online_users: usize,
65	/// Total number of connections (may be > users if multiple tabs)
66	pub total_connections: usize,
67	/// Users per tenant
68	pub users_per_tenant: HashMap<TnId, usize>,
69}
70
71/// Type alias for the user registry map: TnId -> id_tag -> Vec<UserConnection>
72type UserRegistryMap = HashMap<TnId, HashMap<Box<str>, Vec<UserConnection>>>;
73
/// Tuning options for the broadcast manager.
#[derive(Clone, Debug)]
pub struct BroadcastConfig {
	/// Capacity of each connection's `tokio::sync::broadcast` channel, i.e. how many
	/// undelivered messages may queue per connection. Keep this at least 1:
	/// `tokio::sync::broadcast::channel` rejects a zero capacity.
	pub buffer_size: usize,
}

impl Default for BroadcastConfig {
	/// Buffers up to 128 messages per connection.
	fn default() -> Self {
		let buffer_size = 128;
		BroadcastConfig { buffer_size }
	}
}
86
87/// Manages direct user messaging via WebSocket
88pub struct BroadcastManager {
89	/// User registry for direct messaging
90	users: Arc<RwLock<UserRegistryMap>>,
91	config: BroadcastConfig,
92}
93
94impl BroadcastManager {
95	/// Create a new manager with default config
96	pub fn new() -> Self {
97		Self::with_config(BroadcastConfig::default())
98	}
99
100	/// Create with custom config
101	pub fn with_config(config: BroadcastConfig) -> Self {
102		Self { users: Arc::new(RwLock::new(HashMap::new())), config }
103	}
104
105	/// Register a user connection for direct messaging
106	///
107	/// Returns a receiver for messages targeted at this user.
108	/// The connection_id should be a unique identifier (UUID) for this specific
109	/// connection, allowing multiple connections per user (multiple tabs/devices).
110	pub async fn register_user(
111		&self,
112		tn_id: TnId,
113		id_tag: &str,
114		connection_id: &str,
115	) -> broadcast::Receiver<BroadcastMessage> {
116		let (sender, receiver) = broadcast::channel(self.config.buffer_size);
117
118		let connection = UserConnection {
119			id_tag: id_tag.into(),
120			tn_id,
121			connection_id: connection_id.into(),
122			connected_at: now_timestamp(),
123			sender,
124		};
125
126		let mut users = self.users.write().await;
127		users
128			.entry(tn_id)
129			.or_default()
130			.entry(id_tag.into())
131			.or_default()
132			.push(connection);
133
134		tracing::debug!(tn_id = ?tn_id, id_tag = %id_tag, connection_id = %connection_id, "User registered");
135		receiver
136	}
137
138	/// Unregister a user connection
139	///
140	/// Removes the specific connection identified by connection_id.
141	/// Other connections for the same user (other tabs) are preserved.
142	pub async fn unregister_user(&self, tn_id: TnId, id_tag: &str, connection_id: &str) {
143		let mut users = self.users.write().await;
144
145		if let Some(tenant_users) = users.get_mut(&tn_id) {
146			if let Some(connections) = tenant_users.get_mut(id_tag) {
147				connections.retain(|conn| conn.connection_id.as_ref() != connection_id);
148
149				// Clean up empty entries
150				if connections.is_empty() {
151					tenant_users.remove(id_tag);
152				}
153			}
154
155			// Clean up empty tenant entries
156			if tenant_users.is_empty() {
157				users.remove(&tn_id);
158			}
159		}
160
161		tracing::debug!(tn_id = ?tn_id, id_tag = %id_tag, connection_id = %connection_id, "User unregistered");
162	}
163
164	/// Send a message to a specific user
165	///
166	/// Delivers the message to all connections for the user (multiple tabs/devices).
167	/// Returns `DeliveryResult::Delivered(n)` with the number of connections that
168	/// received the message, or `DeliveryResult::UserOffline` if the user has no
169	/// active connections.
170	pub async fn send_to_user(
171		&self,
172		tn_id: TnId,
173		id_tag: &str,
174		msg: BroadcastMessage,
175	) -> DeliveryResult {
176		let users = self.users.read().await;
177
178		if let Some(tenant_users) = users.get(&tn_id) {
179			if let Some(connections) = tenant_users.get(id_tag) {
180				let mut delivered = 0;
181				for conn in connections {
182					if conn.sender.send(msg.clone()).is_ok() {
183						delivered += 1;
184					}
185				}
186				if delivered > 0 {
187					return DeliveryResult::Delivered(delivered);
188				}
189			}
190		}
191
192		DeliveryResult::UserOffline
193	}
194
195	/// Send a message to all users in a tenant
196	///
197	/// Broadcasts the message to all connections for all users in the tenant.
198	/// Returns the total number of connections that received the message.
199	pub async fn send_to_tenant(&self, tn_id: TnId, msg: BroadcastMessage) -> usize {
200		let users = self.users.read().await;
201
202		let mut delivered = 0;
203		if let Some(tenant_users) = users.get(&tn_id) {
204			for connections in tenant_users.values() {
205				for conn in connections {
206					if conn.sender.send(msg.clone()).is_ok() {
207						delivered += 1;
208					}
209				}
210			}
211		}
212		delivered
213	}
214
215	/// Check if a user is currently online (has at least one connection)
216	pub async fn is_user_online(&self, tn_id: TnId, id_tag: &str) -> bool {
217		let users = self.users.read().await;
218
219		users
220			.get(&tn_id)
221			.and_then(|tenant_users| tenant_users.get(id_tag))
222			.is_some_and(|connections| !connections.is_empty())
223	}
224
225	/// Get list of all online users for a tenant
226	pub async fn online_users(&self, tn_id: TnId) -> Vec<Box<str>> {
227		let users = self.users.read().await;
228
229		users
230			.get(&tn_id)
231			.map(|tenant_users| tenant_users.keys().cloned().collect())
232			.unwrap_or_default()
233	}
234
235	/// Get user registry statistics
236	pub async fn user_stats(&self) -> UserRegistryStats {
237		let users = self.users.read().await;
238
239		let mut online_users = 0;
240		let mut total_connections = 0;
241		let mut users_per_tenant = HashMap::new();
242
243		for (tn_id, tenant_users) in users.iter() {
244			let tenant_user_count = tenant_users.len();
245			online_users += tenant_user_count;
246			users_per_tenant.insert(*tn_id, tenant_user_count);
247
248			for connections in tenant_users.values() {
249				total_connections += connections.len();
250			}
251		}
252
253		UserRegistryStats { online_users, total_connections, users_per_tenant }
254	}
255
256	/// Cleanup disconnected users (users with no active receivers)
257	pub async fn cleanup_users(&self) {
258		let mut users = self.users.write().await;
259
260		for tenant_users in users.values_mut() {
261			for connections in tenant_users.values_mut() {
262				connections.retain(|conn| conn.sender.receiver_count() > 0);
263			}
264			tenant_users.retain(|_, connections| !connections.is_empty());
265		}
266		users.retain(|_, tenant_users| !tenant_users.is_empty());
267	}
268}
269
270impl Default for BroadcastManager {
271	fn default() -> Self {
272		Self::new()
273	}
274}
275
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch, instead of panicking.
fn now_timestamp() -> u64 {
	use std::time::{SystemTime, UNIX_EPOCH};

	match SystemTime::now().duration_since(UNIX_EPOCH) {
		Ok(elapsed) => elapsed.as_secs(),
		Err(_) => 0,
	}
}
283
#[cfg(test)]
mod tests {
	use super::*;

	#[tokio::test]
	async fn test_register_user() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let _rx = manager.register_user(tn_id, "alice", "conn-1").await;

		assert!(manager.is_user_online(tn_id, "alice").await);
		assert!(!manager.is_user_online(tn_id, "bob").await);

		let stats = manager.user_stats().await;
		assert_eq!(stats.online_users, 1);
		assert_eq!(stats.total_connections, 1);
	}

	#[tokio::test]
	async fn test_multiple_connections_per_user() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let _rx1 = manager.register_user(tn_id, "alice", "conn-1").await;
		let _rx2 = manager.register_user(tn_id, "alice", "conn-2").await;

		let stats = manager.user_stats().await;
		assert_eq!(stats.online_users, 1);
		assert_eq!(stats.total_connections, 2);
	}

	#[tokio::test]
	async fn test_send_to_user() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let mut rx = manager.register_user(tn_id, "alice", "conn-1").await;

		let msg = BroadcastMessage::new("ACTION", serde_json::json!({ "type": "MSG" }), "system");
		let result = manager.send_to_user(tn_id, "alice", msg).await;

		assert_eq!(result, DeliveryResult::Delivered(1));

		let received = rx.recv().await.unwrap();
		assert_eq!(received.cmd, "ACTION");
	}

	#[tokio::test]
	async fn test_send_to_all_connections_of_user() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let mut rx1 = manager.register_user(tn_id, "alice", "conn-1").await;
		let mut rx2 = manager.register_user(tn_id, "alice", "conn-2").await;

		let msg = BroadcastMessage::new("ACTION", serde_json::json!({}), "system");
		let result = manager.send_to_user(tn_id, "alice", msg).await;
		assert_eq!(result, DeliveryResult::Delivered(2));

		assert_eq!(rx1.recv().await.unwrap().cmd, "ACTION");
		assert_eq!(rx2.recv().await.unwrap().cmd, "ACTION");
	}

	#[tokio::test]
	async fn test_send_to_offline_user() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let msg = BroadcastMessage::new("ACTION", serde_json::json!({ "type": "MSG" }), "system");
		let result = manager.send_to_user(tn_id, "bob", msg).await;

		assert_eq!(result, DeliveryResult::UserOffline);
	}

	#[tokio::test]
	async fn test_send_to_user_with_dropped_receiver_is_offline() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		// The connection stays registered, but nobody is listening on it any more.
		let rx = manager.register_user(tn_id, "alice", "conn-1").await;
		drop(rx);

		let msg = BroadcastMessage::new("ACTION", serde_json::json!({}), "system");
		let result = manager.send_to_user(tn_id, "alice", msg).await;
		assert_eq!(result, DeliveryResult::UserOffline);
	}

	#[tokio::test]
	async fn test_unregister_user() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let _rx = manager.register_user(tn_id, "alice", "conn-1").await;
		assert!(manager.is_user_online(tn_id, "alice").await);

		manager.unregister_user(tn_id, "alice", "conn-1").await;
		assert!(!manager.is_user_online(tn_id, "alice").await);
	}

	#[tokio::test]
	async fn test_unregister_keeps_other_connections() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let _rx1 = manager.register_user(tn_id, "alice", "conn-1").await;
		let _rx2 = manager.register_user(tn_id, "alice", "conn-2").await;

		manager.unregister_user(tn_id, "alice", "conn-1").await;

		assert!(manager.is_user_online(tn_id, "alice").await);
		let stats = manager.user_stats().await;
		assert_eq!(stats.online_users, 1);
		assert_eq!(stats.total_connections, 1);
	}

	#[tokio::test]
	async fn test_multi_tenant_isolation() {
		let manager = BroadcastManager::new();
		let tn1 = TnId(1);
		let tn2 = TnId(2);

		let _rx1 = manager.register_user(tn1, "alice", "conn-1").await;
		let mut rx2 = manager.register_user(tn2, "alice", "conn-2").await;

		assert!(manager.is_user_online(tn1, "alice").await);
		assert!(manager.is_user_online(tn2, "alice").await);

		let msg = BroadcastMessage::new("test", serde_json::json!({}), "system");
		let result = manager.send_to_user(tn1, "alice", msg).await;
		assert_eq!(result, DeliveryResult::Delivered(1));

		// The same id_tag in another tenant must not receive the message.
		assert!(rx2.try_recv().is_err());

		let stats = manager.user_stats().await;
		assert_eq!(stats.online_users, 2);
	}

	#[tokio::test]
	async fn test_send_to_tenant() {
		let manager = BroadcastManager::new();
		let tn1 = TnId(1);
		let tn2 = TnId(2);

		let mut alice = manager.register_user(tn1, "alice", "conn-1").await;
		let mut bob = manager.register_user(tn1, "bob", "conn-2").await;
		let mut carol = manager.register_user(tn2, "carol", "conn-3").await;

		let msg = BroadcastMessage::new("NOTICE", serde_json::json!({}), "system");
		assert_eq!(manager.send_to_tenant(tn1, msg).await, 2);

		assert_eq!(alice.recv().await.unwrap().cmd, "NOTICE");
		assert_eq!(bob.recv().await.unwrap().cmd, "NOTICE");
		assert!(carol.try_recv().is_err());

		let msg = BroadcastMessage::new("NOTICE", serde_json::json!({}), "system");
		assert_eq!(manager.send_to_tenant(TnId(3), msg).await, 0);
	}

	#[tokio::test]
	async fn test_cleanup_users_removes_dead_connections() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let dead = manager.register_user(tn_id, "alice", "conn-1").await;
		let _live = manager.register_user(tn_id, "alice", "conn-2").await;
		let gone = manager.register_user(tn_id, "bob", "conn-3").await;
		drop(dead);
		drop(gone);

		manager.cleanup_users().await;

		assert!(manager.is_user_online(tn_id, "alice").await);
		assert!(!manager.is_user_online(tn_id, "bob").await);
		let stats = manager.user_stats().await;
		assert_eq!(stats.online_users, 1);
		assert_eq!(stats.total_connections, 1);
	}

	#[tokio::test]
	async fn test_online_users() {
		let manager = BroadcastManager::new();
		let tn_id = TnId(1);

		let _a = manager.register_user(tn_id, "alice", "conn-1").await;
		let _b = manager.register_user(tn_id, "bob", "conn-2").await;

		let mut online = manager.online_users(tn_id).await;
		online.sort();
		assert_eq!(online, vec![Box::<str>::from("alice"), Box::<str>::from("bob")]);
		assert!(manager.online_users(TnId(2)).await.is_empty());
	}
}
375
376// vim: ts=4