Skip to main content

karbon_framework/channel/
channel_registry.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::Arc;
3use tokio::sync::{broadcast, RwLock};
4use serde::Serialize;
5use axum::extract::ws::{Message, WebSocket};
6use futures::{SinkExt, StreamExt};
7
/// A message as it travels over the registry's internal broadcast channel.
///
/// Every open socket receives every `ChannelMessage`; each connection filters on
/// `channel` itself before forwarding anything to its client.
#[derive(Clone, Debug)]
struct ChannelMessage {
    /// Name of the channel/room the message belongs to.
    channel: String,
    /// Event name forwarded to clients unchanged.
    event: String,
    /// Payload, already serialized to a JSON string.
    payload: String,
    /// Id of the originating connection, or `None` for server-side broadcasts.
    sender_id: Option<u64>,
}
16
17/// Wire format for client ↔ server WebSocket messages
18#[derive(Debug, Clone, Serialize, serde::Deserialize)]
19struct WireMessage {
20    channel: String,
21    event: String,
22    #[serde(default)]
23    payload: serde_json::Value,
24}
25
26/// Manages WebSocket channels/rooms with typed messages.
27///
28/// ```ignore
29/// use framework::channel::ChannelRegistry;
30/// use framework::http::ws::websocket_handler_with_state;
31///
32/// let channels = ChannelRegistry::new();
33///
34/// // Mount in router (uses the existing ws helper)
35/// Router::new()
36///     .route("/ws/channels", get(|ws: WebSocketUpgrade, State(ch): State<ChannelRegistry>| {
37///         websocket_handler_with_state(ws, ch, |socket, ch| ch.handle_socket(socket))
38///     }));
39///
40/// // Broadcast from anywhere (controller, job, event handler)
41/// channels.broadcast("chat/room-1", "new_message", &msg).await;
42/// ```
43#[derive(Clone)]
44pub struct ChannelRegistry {
45    tx: broadcast::Sender<ChannelMessage>,
46    rooms: Arc<RwLock<HashMap<String, HashSet<u64>>>>,
47    next_client_id: Arc<std::sync::atomic::AtomicU64>,
48}
49
50impl ChannelRegistry {
51    pub fn new() -> Self {
52        let (tx, _) = broadcast::channel(1024);
53        Self {
54            tx,
55            rooms: Arc::new(RwLock::new(HashMap::new())),
56            next_client_id: Arc::new(std::sync::atomic::AtomicU64::new(1)),
57        }
58    }
59
60    /// Broadcast a typed message to all clients subscribed to a channel
61    pub async fn broadcast<T: Serialize>(&self, channel: &str, event: &str, data: &T) {
62        let payload = match serde_json::to_string(data) {
63            Ok(p) => p,
64            Err(e) => {
65                tracing::warn!(channel = %channel, error = %e, "Failed to serialize channel message");
66                return;
67            }
68        };
69        let _ = self.tx.send(ChannelMessage {
70            channel: channel.to_string(),
71            event: event.to_string(),
72            payload,
73            sender_id: None,
74        });
75    }
76
77    /// Broadcast a raw JSON value to a channel
78    pub async fn broadcast_raw(&self, channel: &str, event: &str, payload: serde_json::Value) {
79        let _ = self.tx.send(ChannelMessage {
80            channel: channel.to_string(),
81            event: event.to_string(),
82            payload: payload.to_string(),
83            sender_id: None,
84        });
85    }
86
87    /// Get the number of connected clients in a channel
88    pub async fn client_count(&self, channel: &str) -> usize {
89        self.rooms.read().await
90            .get(channel)
91            .map(|s| s.len())
92            .unwrap_or(0)
93    }
94
95    /// Get all active channel names
96    pub async fn active_channels(&self) -> Vec<String> {
97        self.rooms.read().await
98            .keys()
99            .cloned()
100            .collect()
101    }
102
103    /// Handle a WebSocket connection with the channel protocol.
104    ///
105    /// The client sends JSON messages:
106    /// - Join:    `{"channel": "chat/room-1", "event": "join", "payload": {}}`
107    /// - Leave:   `{"channel": "chat/room-1", "event": "leave", "payload": {}}`
108    /// - Message: `{"channel": "chat/room-1", "event": "message", "payload": {"text": "hello"}}`
109    pub async fn handle_socket(self, socket: WebSocket) {
110        let client_id = self.next_client_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
111        let mut rx = self.tx.subscribe();
112        let (mut ws_tx, mut ws_rx) = socket.split();
113
114        let subscribed: Arc<RwLock<HashSet<String>>> = Arc::new(RwLock::new(HashSet::new()));
115
116        let sub_read = subscribed.clone();
117
118        let mut send_task = tokio::spawn(async move {
119            while let Ok(msg) = rx.recv().await {
120                if msg.sender_id == Some(client_id) {
121                    continue;
122                }
123
124                let subs = sub_read.read().await;
125                if !subs.contains(&msg.channel) {
126                    continue;
127                }
128                drop(subs);
129
130                let wire = serde_json::json!({
131                    "channel": msg.channel,
132                    "event": msg.event,
133                    "payload": serde_json::from_str::<serde_json::Value>(&msg.payload).unwrap_or_default(),
134                });
135
136                if ws_tx.send(Message::Text(wire.to_string().into())).await.is_err() {
137                    break;
138                }
139            }
140        });
141
142        let sub_write = subscribed.clone();
143        let rooms = self.rooms.clone();
144        let tx = self.tx.clone();
145
146        let mut recv_task = tokio::spawn(async move {
147            while let Some(Ok(msg)) = ws_rx.next().await {
148                let text = match msg {
149                    Message::Text(t) => t.to_string(),
150                    Message::Close(_) => break,
151                    _ => continue,
152                };
153
154                let Ok(wire) = serde_json::from_str::<WireMessage>(&text) else {
155                    continue;
156                };
157
158                match wire.event.as_str() {
159                    "join" => {
160                        sub_write.write().await.insert(wire.channel.clone());
161                        rooms.write().await
162                            .entry(wire.channel.clone())
163                            .or_default()
164                            .insert(client_id);
165                    }
166                    "leave" => {
167                        sub_write.write().await.remove(&wire.channel);
168                        let mut rooms = rooms.write().await;
169                        if let Some(set) = rooms.get_mut(&wire.channel) {
170                            set.remove(&client_id);
171                            if set.is_empty() {
172                                rooms.remove(&wire.channel);
173                            }
174                        }
175                    }
176                    _ => {
177                        let subs = sub_write.read().await;
178                        if subs.contains(&wire.channel) {
179                            let _ = tx.send(ChannelMessage {
180                                channel: wire.channel,
181                                event: wire.event,
182                                payload: wire.payload.to_string(),
183                                sender_id: Some(client_id),
184                            });
185                        }
186                    }
187                }
188            }
189        });
190
191        tokio::select! {
192            _ = &mut send_task => recv_task.abort(),
193            _ = &mut recv_task => send_task.abort(),
194        }
195
196        // Cleanup: remove client from all rooms
197        let subs = subscribed.read().await;
198        let mut rooms = self.rooms.write().await;
199        for channel in subs.iter() {
200            if let Some(set) = rooms.get_mut(channel) {
201                set.remove(&client_id);
202                if set.is_empty() {
203                    rooms.remove(channel);
204                }
205            }
206        }
207    }
208}
209
210impl Default for ChannelRegistry {
211    fn default() -> Self {
212        Self::new()
213    }
214}