1use std::sync::Arc;
12
13use async_trait::async_trait;
14use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
15use axum::response::IntoResponse;
16use futures::StreamExt;
17use indexmap::IndexMap;
18use parking_lot::RwLock;
19use serde::{Deserialize, Serialize};
20use tokio::sync::broadcast;
21
22#[derive(Debug, Clone)]
23pub struct ChannelBroadcast {
24 pub channel: String,
25 pub event: String,
26 pub data: serde_json::Value,
27}
28
29#[derive(Clone, Default)]
30pub struct BellowsServer {
31 inner: Arc<BellowsInner>,
32}
33
34struct BellowsInner {
35 channels: RwLock<IndexMap<String, broadcast::Sender<ChannelBroadcast>>>,
36}
37
38impl Default for BellowsInner {
39 fn default() -> Self {
40 Self {
41 channels: RwLock::new(IndexMap::new()),
42 }
43 }
44}
45
46impl BellowsServer {
47 pub fn new() -> Self {
48 Self::default()
49 }
50
51 fn channel(&self, name: &str) -> broadcast::Sender<ChannelBroadcast> {
52 if let Some(tx) = self.inner.channels.read().get(name) {
53 return tx.clone();
54 }
55 let (tx, _rx) = broadcast::channel::<ChannelBroadcast>(1024);
56 self.inner
57 .channels
58 .write()
59 .insert(name.to_string(), tx.clone());
60 tx
61 }
62
63 pub fn publish(&self, channel: &str, event: &str, data: serde_json::Value) {
64 let tx = self.channel(channel);
65 let _ = tx.send(ChannelBroadcast {
66 channel: channel.to_string(),
67 event: event.to_string(),
68 data,
69 });
70 }
71
72 pub async fn upgrade(&self, ws: WebSocketUpgrade) -> impl IntoResponse {
73 let server = self.clone();
74 ws.on_upgrade(move |socket| handle_socket(server, socket))
75 }
76}
77
78#[derive(Debug, Deserialize)]
79#[serde(tag = "event")]
80enum ClientMessage {
81 #[serde(rename = "pusher:subscribe")]
82 Subscribe { data: SubscribeData },
83 #[serde(rename = "pusher:unsubscribe")]
84 Unsubscribe {
85 #[allow(dead_code)]
86 data: SubscribeData,
87 },
88}
89
90#[derive(Debug, Deserialize)]
91struct SubscribeData {
92 channel: String,
93}
94
95#[derive(Debug, Serialize)]
96struct ServerMessage<'a> {
97 event: &'a str,
98 channel: Option<&'a str>,
99 data: serde_json::Value,
100}
101
102async fn handle_socket(server: BellowsServer, mut socket: WebSocket) {
103 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<ChannelBroadcast>();
104
105 let _ = socket
106 .send(Message::Text(
107 serde_json::to_string(&ServerMessage {
108 event: "pusher:connection_established",
109 channel: None,
110 data: serde_json::json!({
111 "socket_id": uuid::Uuid::new_v4().to_string(),
112 "activity_timeout": 120,
113 }),
114 })
115 .unwrap(),
116 ))
117 .await;
118
119 let mut subscriptions: Vec<tokio::task::JoinHandle<()>> = Vec::new();
120
121 loop {
122 tokio::select! {
123 msg = socket.next() => {
124 let Some(Ok(msg)) = msg else { break };
125 let Message::Text(text) = msg else { continue };
126 let Ok(client_msg) = serde_json::from_str::<ClientMessage>(&text) else {
127 continue;
128 };
129 match client_msg {
130 ClientMessage::Subscribe { data } => {
131 let tx_clone = tx.clone();
132 let mut sub_rx = server.channel(&data.channel).subscribe();
133 let channel = data.channel.clone();
134 let handle = tokio::spawn(async move {
135 while let Ok(broadcast) = sub_rx.recv().await {
136 let _ = tx_clone.send(broadcast);
137 }
138 drop(channel);
139 });
140 subscriptions.push(handle);
141 let _ = socket.send(Message::Text(serde_json::to_string(&ServerMessage {
142 event: "pusher_internal:subscription_succeeded",
143 channel: Some(&data.channel),
144 data: serde_json::json!({}),
145 }).unwrap())).await;
146 }
147 ClientMessage::Unsubscribe { .. } => {
148 }
150 }
151 }
152 Some(broadcast) = rx.recv() => {
153 let msg = ServerMessage {
154 event: &broadcast.event,
155 channel: Some(&broadcast.channel),
156 data: broadcast.data,
157 };
158 if socket.send(Message::Text(serde_json::to_string(&msg).unwrap())).await.is_err() {
159 break;
160 }
161 }
162 }
163 }
164
165 for h in subscriptions {
166 h.abort();
167 }
168}
169
170#[async_trait]
173pub trait Broadcastable: Send + Sync {
174 fn channel(&self) -> String;
175 fn event_name(&self) -> String;
176 fn payload(&self) -> serde_json::Value;
177}
178
179pub fn broadcast<B: Broadcastable>(server: &BellowsServer, event: B) {
180 server.publish(&event.channel(), &event.event_name(), event.payload());
181}