// karbon_framework/channel/channel_registry.rs
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

use axum::extract::ws::{Message, WebSocket};
use futures::{SinkExt, StreamExt};
use serde::Serialize;
use tokio::sync::{broadcast, RwLock};

/// One message travelling through the registry's internal broadcast channel.
#[derive(Clone, Debug)]
struct ChannelMessage {
    /// Name of the channel the message is addressed to.
    channel: String,
    /// Event name carried alongside the payload.
    event: String,
    /// Payload, already serialized to a JSON string.
    payload: String,
    /// Id of the socket that produced the message, or `None` for server-side broadcasts.
    sender_id: Option<u64>,
}

17#[derive(Debug, Clone, Serialize, serde::Deserialize)]
19struct WireMessage {
20 channel: String,
21 event: String,
22 #[serde(default)]
23 payload: serde_json::Value,
24}
25
26#[derive(Clone)]
44pub struct ChannelRegistry {
45 tx: broadcast::Sender<ChannelMessage>,
46 rooms: Arc<RwLock<HashMap<String, HashSet<u64>>>>,
47 next_client_id: Arc<std::sync::atomic::AtomicU64>,
48}
49
50impl ChannelRegistry {
51 pub fn new() -> Self {
52 let (tx, _) = broadcast::channel(1024);
53 Self {
54 tx,
55 rooms: Arc::new(RwLock::new(HashMap::new())),
56 next_client_id: Arc::new(std::sync::atomic::AtomicU64::new(1)),
57 }
58 }
59
60 pub async fn broadcast<T: Serialize>(&self, channel: &str, event: &str, data: &T) {
62 let payload = match serde_json::to_string(data) {
63 Ok(p) => p,
64 Err(e) => {
65 tracing::warn!(channel = %channel, error = %e, "Failed to serialize channel message");
66 return;
67 }
68 };
69 let _ = self.tx.send(ChannelMessage {
70 channel: channel.to_string(),
71 event: event.to_string(),
72 payload,
73 sender_id: None,
74 });
75 }
76
77 pub async fn broadcast_raw(&self, channel: &str, event: &str, payload: serde_json::Value) {
79 let _ = self.tx.send(ChannelMessage {
80 channel: channel.to_string(),
81 event: event.to_string(),
82 payload: payload.to_string(),
83 sender_id: None,
84 });
85 }
86
87 pub async fn client_count(&self, channel: &str) -> usize {
89 self.rooms.read().await
90 .get(channel)
91 .map(|s| s.len())
92 .unwrap_or(0)
93 }
94
95 pub async fn active_channels(&self) -> Vec<String> {
97 self.rooms.read().await
98 .keys()
99 .cloned()
100 .collect()
101 }
102
103 pub async fn handle_socket(self, socket: WebSocket) {
110 let client_id = self.next_client_id.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
111 let mut rx = self.tx.subscribe();
112 let (mut ws_tx, mut ws_rx) = socket.split();
113
114 let subscribed: Arc<RwLock<HashSet<String>>> = Arc::new(RwLock::new(HashSet::new()));
115
116 let sub_read = subscribed.clone();
117
118 let mut send_task = tokio::spawn(async move {
119 while let Ok(msg) = rx.recv().await {
120 if msg.sender_id == Some(client_id) {
121 continue;
122 }
123
124 let subs = sub_read.read().await;
125 if !subs.contains(&msg.channel) {
126 continue;
127 }
128 drop(subs);
129
130 let wire = serde_json::json!({
131 "channel": msg.channel,
132 "event": msg.event,
133 "payload": serde_json::from_str::<serde_json::Value>(&msg.payload).unwrap_or_default(),
134 });
135
136 if ws_tx.send(Message::Text(wire.to_string().into())).await.is_err() {
137 break;
138 }
139 }
140 });
141
142 let sub_write = subscribed.clone();
143 let rooms = self.rooms.clone();
144 let tx = self.tx.clone();
145
146 let mut recv_task = tokio::spawn(async move {
147 while let Some(Ok(msg)) = ws_rx.next().await {
148 let text = match msg {
149 Message::Text(t) => t.to_string(),
150 Message::Close(_) => break,
151 _ => continue,
152 };
153
154 let Ok(wire) = serde_json::from_str::<WireMessage>(&text) else {
155 continue;
156 };
157
158 match wire.event.as_str() {
159 "join" => {
160 sub_write.write().await.insert(wire.channel.clone());
161 rooms.write().await
162 .entry(wire.channel.clone())
163 .or_default()
164 .insert(client_id);
165 }
166 "leave" => {
167 sub_write.write().await.remove(&wire.channel);
168 let mut rooms = rooms.write().await;
169 if let Some(set) = rooms.get_mut(&wire.channel) {
170 set.remove(&client_id);
171 if set.is_empty() {
172 rooms.remove(&wire.channel);
173 }
174 }
175 }
176 _ => {
177 let subs = sub_write.read().await;
178 if subs.contains(&wire.channel) {
179 let _ = tx.send(ChannelMessage {
180 channel: wire.channel,
181 event: wire.event,
182 payload: wire.payload.to_string(),
183 sender_id: Some(client_id),
184 });
185 }
186 }
187 }
188 }
189 });
190
191 tokio::select! {
192 _ = &mut send_task => recv_task.abort(),
193 _ = &mut recv_task => send_task.abort(),
194 }
195
196 let subs = subscribed.read().await;
198 let mut rooms = self.rooms.write().await;
199 for channel in subs.iter() {
200 if let Some(set) = rooms.get_mut(channel) {
201 set.remove(&client_id);
202 if set.is_empty() {
203 rooms.remove(channel);
204 }
205 }
206 }
207 }
208}
209
210impl Default for ChannelRegistry {
211 fn default() -> Self {
212 Self::new()
213 }
214}