1use std::collections::HashMap;
12use std::sync::Arc;
13
14use async_trait::async_trait;
15use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
16use axum::response::IntoResponse;
17use futures::StreamExt;
18use indexmap::IndexMap;
19use parking_lot::RwLock;
20use serde::{Deserialize, Serialize};
21use tokio::sync::broadcast;
22
23#[derive(Debug, Clone)]
24pub struct ChannelBroadcast {
25 pub channel: String,
26 pub event: String,
27 pub data: serde_json::Value,
28}
29
30#[derive(Clone, Default)]
31pub struct BellowsServer {
32 inner: Arc<BellowsInner>,
33}
34
35struct BellowsInner {
36 channels: RwLock<IndexMap<String, broadcast::Sender<ChannelBroadcast>>>,
37}
38
39impl Default for BellowsInner {
40 fn default() -> Self {
41 Self {
42 channels: RwLock::new(IndexMap::new()),
43 }
44 }
45}
46
47impl BellowsServer {
48 pub fn new() -> Self {
49 Self::default()
50 }
51
52 fn channel(&self, name: &str) -> broadcast::Sender<ChannelBroadcast> {
53 if let Some(tx) = self.inner.channels.read().get(name) {
54 return tx.clone();
55 }
56 let (tx, _rx) = broadcast::channel::<ChannelBroadcast>(1024);
57 self.inner
58 .channels
59 .write()
60 .insert(name.to_string(), tx.clone());
61 tx
62 }
63
64 pub fn publish(&self, channel: &str, event: &str, data: serde_json::Value) {
65 let tx = self.channel(channel);
66 let _ = tx.send(ChannelBroadcast {
67 channel: channel.to_string(),
68 event: event.to_string(),
69 data,
70 });
71 }
72
73 pub fn subscriber_count(&self, channel: &str) -> usize {
77 self.inner
78 .channels
79 .read()
80 .get(channel)
81 .map(|tx| tx.receiver_count())
82 .unwrap_or(0)
83 }
84
85 pub async fn upgrade(&self, ws: WebSocketUpgrade) -> impl IntoResponse {
86 let server = self.clone();
87 ws.on_upgrade(move |socket| handle_socket(server, socket))
88 }
89}
90
91#[derive(Debug, Deserialize)]
92#[serde(tag = "event")]
93enum ClientMessage {
94 #[serde(rename = "pusher:subscribe")]
95 Subscribe { data: SubscribeData },
96 #[serde(rename = "pusher:unsubscribe")]
97 Unsubscribe { data: SubscribeData },
98}
99
100#[derive(Debug, Deserialize)]
101struct SubscribeData {
102 channel: String,
103}
104
105#[derive(Debug, Serialize)]
106struct ServerMessage<'a> {
107 event: &'a str,
108 channel: Option<&'a str>,
109 data: serde_json::Value,
110}
111
112async fn handle_socket(server: BellowsServer, mut socket: WebSocket) {
113 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<ChannelBroadcast>();
114
115 let _ = socket
116 .send(Message::Text(
117 serde_json::to_string(&ServerMessage {
118 event: "pusher:connection_established",
119 channel: None,
120 data: serde_json::json!({
121 "socket_id": uuid::Uuid::new_v4().to_string(),
122 "activity_timeout": 120,
123 }),
124 })
125 .unwrap(),
126 ))
127 .await;
128
129 let mut subscriptions: HashMap<String, tokio::task::JoinHandle<()>> = HashMap::new();
134
135 loop {
136 tokio::select! {
137 msg = socket.next() => {
138 let Some(Ok(msg)) = msg else { break };
139 let Message::Text(text) = msg else { continue };
140 let Ok(client_msg) = serde_json::from_str::<ClientMessage>(&text) else {
141 continue;
142 };
143 match client_msg {
144 ClientMessage::Subscribe { data } => {
145 if let Some(prior) = subscriptions.remove(&data.channel) {
149 prior.abort();
150 }
151 let tx_clone = tx.clone();
152 let mut sub_rx = server.channel(&data.channel).subscribe();
153 let channel = data.channel.clone();
154 let handle = tokio::spawn(async move {
155 while let Ok(broadcast) = sub_rx.recv().await {
156 let _ = tx_clone.send(broadcast);
157 }
158 drop(channel);
159 });
160 subscriptions.insert(data.channel.clone(), handle);
161 let _ = socket.send(Message::Text(serde_json::to_string(&ServerMessage {
162 event: "pusher_internal:subscription_succeeded",
163 channel: Some(&data.channel),
164 data: serde_json::json!({}),
165 }).unwrap())).await;
166 }
167 ClientMessage::Unsubscribe { data } => {
168 if let Some(handle) = subscriptions.remove(&data.channel) {
169 handle.abort();
170 tracing::trace!(channel = %data.channel, "bellows: unsubscribed");
174 }
175 }
176 }
177 }
178 Some(broadcast) = rx.recv() => {
179 let msg = ServerMessage {
180 event: &broadcast.event,
181 channel: Some(&broadcast.channel),
182 data: broadcast.data,
183 };
184 if socket.send(Message::Text(serde_json::to_string(&msg).unwrap())).await.is_err() {
185 break;
186 }
187 }
188 }
189 }
190
191 for (_channel, h) in subscriptions.drain() {
195 h.abort();
196 }
197}
198
199#[async_trait]
202pub trait Broadcastable: Send + Sync {
203 fn channel(&self) -> String;
204 fn event_name(&self) -> String;
205 fn payload(&self) -> serde_json::Value;
206}
207
208pub fn broadcast<B: Broadcastable>(server: &BellowsServer, event: B) {
209 server.publish(&event.channel(), &event.event_name(), event.payload());
210}