Skip to main content

rusher_server/
lib.rs

1use std::{
2    collections::{HashMap, HashSet},
3    fmt::{self, Debug},
4    sync::Arc,
5};
6
7use anyhow::bail;
8use async_trait::async_trait;
9use authentication::check_signature_middleware;
10use axum::{
11    body::Body,
12    extract::{ws::Message, Path, Request, State, WebSocketUpgrade},
13    http::StatusCode,
14    middleware::{self, Next},
15    response::{IntoResponse, Response},
16    routing::{any, get, post},
17    Extension, Json, Router,
18};
19use futures::{SinkExt, StreamExt};
20use rand::Rng;
21use rusher_core::{ChannelName, ConnectionInfo, CustomEvent, ServerEvent, SocketId};
22use rusher_pubsub::{AnyBroker, Broker, Connection};
23use serde::Deserialize;
24use serde_json::{json, Value as JsonValue};
25use tokio::sync::mpsc;
26
27mod authentication;
28mod websocket;
29
30pub use axum::serve;
31use tower_http::trace::{DefaultOnResponse, TraceLayer};
32use tracing::{debug, error, info_span, Instrument, Level};
33use websocket::ConnectionProtocol;
34
35#[derive(Debug, Clone, PartialEq, Eq, Hash)]
36pub struct App {
37    pub id: AppId,
38    secret: AppSecret,
39}
40
/// Identifier of an [`App`], taken from the `:app` path segment of incoming requests.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct AppId(String);

impl fmt::Display for AppId {
    /// Writes the raw identifier; width/fill flags supplied by the caller are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}
49
/// Secret shared between the server and an [`App`].
///
/// `Debug` is implemented by hand rather than derived so that formatting an `AppSecret`
/// (or an `App`, whose derived `Debug` includes it) with `{:?}` never reveals the secret
/// in logs or panic messages.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct AppSecret(String);

impl fmt::Debug for AppSecret {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("AppSecret").field(&"<redacted>").finish()
    }
}
52
53impl App {
54    pub fn new(id: impl Into<String>, secret: impl Into<String>) -> Self {
55        Self {
56            id: AppId(id.into()),
57            secret: AppSecret(secret.into()),
58        }
59    }
60}
61
62pub trait IntoAppRepository {
63    type AppRepository: AppRepository + 'static;
64    fn into_app_repository(self) -> Self::AppRepository;
65}
66
67impl<I: IntoIterator<Item = (App, AnyBroker)>> IntoAppRepository for I {
68    type AppRepository = HashMap<AppId, (App, AnyBroker)>;
69
70    fn into_app_repository(self) -> Self::AppRepository {
71        self.into_iter()
72            .map(|(app, broker)| (app.id.clone(), (app, broker)))
73            .collect()
74    }
75}
76
77#[async_trait]
78pub trait AppRepository: Send + Sync {
79    async fn secret_for_app(&self, app_id: &AppId) -> Option<AppSecret>;
80    async fn broker_for_app(&self, app_id: &AppId) -> Option<AnyBroker>;
81}
82
83#[async_trait]
84impl AppRepository for HashMap<AppId, (App, AnyBroker)> {
85    async fn secret_for_app(&self, app_id: &AppId) -> Option<AppSecret> {
86        self.get(app_id).map(|(app, _)| app.secret.clone())
87    }
88
89    async fn broker_for_app(&self, app_id: &AppId) -> Option<AnyBroker> {
90        self.get(app_id).map(|(_, broker)| broker).cloned()
91    }
92}
93
94pub fn app(app_repo: impl IntoAppRepository) -> Router {
95    let app_repo = app_repo.into_app_repository();
96    Router::new()
97        .route("/apps/:app/channels", get(list_channels))
98        .route("/apps/:app/events", post(publish))
99        .route_layer(middleware::from_fn(check_signature_middleware))
100        .route("/app/:app", any(handle_ws))
101        .layer(
102            TraceLayer::new_for_http()
103                .on_response(DefaultOnResponse::default().level(Level::INFO))
104                .make_span_with(|request: &Request<_>| {
105                    info_span!(
106                        "request",
107                        uri = ?request.uri(),
108                        method = ?request.method(),
109                    )
110                }),
111        )
112        .route_layer(middleware::from_fn_with_state(
113            Arc::new(app_repo) as Arc<dyn AppRepository>,
114            broker_middleware,
115        ))
116}
117
118async fn broker_middleware(
119    State(app_repo): State<Arc<dyn AppRepository>>,
120    Path(app): Path<String>,
121    mut request: Request,
122    next: Next,
123) -> Response {
124    let app_id = AppId(app);
125    match (
126        app_repo.secret_for_app(&app_id).await,
127        app_repo.broker_for_app(&app_id).await,
128    ) {
129        (Some(secret), Some(broker)) => {
130            request.extensions_mut().insert(app_id.clone());
131            request.extensions_mut().insert(secret);
132            request.extensions_mut().insert(broker);
133            next.run(request)
134                .instrument(info_span!("app_request", app_id = app_id.0))
135                .await
136        }
137        _ => Response::builder()
138            .status(StatusCode::NOT_FOUND)
139            .body(Body::empty())
140            .unwrap(),
141    }
142}
143
144#[derive(Clone, Debug, Deserialize)]
145pub struct EventPayload {
146    pub name: String,
147    pub data: JsonValue,
148    pub channels: Option<HashSet<ChannelName>>,
149    pub channel: Option<ChannelName>,
150    pub socket_id: Option<SocketId>,
151}
152
153async fn publish(
154    Extension(broker): Extension<AnyBroker>,
155    Json(payload): Json<EventPayload>,
156) -> Result<Json<JsonValue>, StatusCode> {
157    let event = payload.name;
158    let data = payload.data;
159
160    let channels = match (payload.channel, payload.channels) {
161        (Some(channel), Some(mut channels)) => {
162            channels.insert(channel);
163            channels
164        }
165        (Some(channel), None) => HashSet::from_iter([channel]),
166        (None, Some(channels)) => channels,
167        _ => HashSet::new(),
168    };
169
170    for channel in channels {
171        let event = ServerEvent::ChannelEvent(CustomEvent {
172            event: event.clone(),
173            data: data.clone(),
174            channel: channel.clone(),
175            user_id: None,
176        });
177
178        broker
179            .publish(channel.as_ref(), event)
180            .await
181            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
182    }
183
184    Ok(Json(json!({ "ok": true })))
185}
186
187async fn list_channels(Extension(broker): Extension<AnyBroker>) -> impl IntoResponse {
188    let channels = broker
189        .subscriptions()
190        .await
191        .into_iter()
192        .map(|(channel, count)| {
193            (
194                channel,
195                json!({
196                    "subscription_count": count,
197                    "user_count": count,
198                }),
199            )
200        })
201        .collect::<HashMap<String, JsonValue>>();
202
203    Json(channels)
204}
205
206async fn handle_ws(
207    Extension(broker): Extension<AnyBroker>,
208    Extension(AppId(app_id)): Extension<AppId>,
209    Extension(AppSecret(secret)): Extension<AppSecret>,
210    ws: WebSocketUpgrade,
211) -> impl IntoResponse {
212    match broker.connect().await {
213        Ok(mut connection) => Ok(ws.on_upgrade(move |ws| async move {
214            let socket_id: SocketId = rand::thread_rng().gen();
215            let _span = info_span!("websocket", %app_id, %socket_id);
216
217            let (mut write_ws, mut read_ws) = ws.split();
218            let (tx, mut rx) = mpsc::channel(64);
219
220            let write_messages = async move {
221                while let Some(msg) = rx.recv().await {
222                    if let Ok(msg) = serde_json::to_string(&msg) {
223                        if let Err(err) = write_ws.send(Message::Text(msg)).await {
224                            bail!(err)
225                        }
226                    }
227                }
228                anyhow::Ok(())
229            }.instrument(info_span!("websocket_connection_write", %app_id, %socket_id));
230
231            let mut proto = ConnectionProtocol {
232                tx: tx.clone(),
233                app_id: app_id.clone(),
234                secret,
235                socket_id: socket_id.clone(),
236                current_user_id: None,
237            };
238
239            let connection_established = ServerEvent::ConnectionEstablished {
240                data: ConnectionInfo {
241                    socket_id: socket_id.clone(),
242                    activity_timeout: 120,
243                },
244            };
245
246            let read_messages = async move {
247                tx.send(connection_established).await?;
248
249                loop {
250                    tokio::select! {
251                        Ok(msg) = connection.recv() => {
252                            tx.send(msg).await?;
253                        },
254
255                        Some(Ok(msg)) = read_ws.next() => {
256                            match msg {
257                                Message::Text(text) => {
258                                    match serde_json::from_str(&text) {
259                                        Ok(msg) => {
260                                            if let Err(error) = proto.handle_message(&mut connection, msg).await {
261                                                error!(%error);
262                                                bail!(error)
263                                            }
264                                        },
265                                        Err(error) => {
266                                            debug!(msg = "could not decode message", %text, %error);
267                                            continue
268                                        },
269                                    };
270                                }
271                                _ => continue,
272                            }
273                        }
274
275                        else => break,
276                    }
277                }
278                anyhow::Ok(())
279            }.instrument(info_span!("websocket_connection_read", %app_id, %socket_id));
280
281            tokio::select! {
282                _ = write_messages => debug!("Writer finished"),
283                _ = read_messages => debug!("Reader finished"),
284            };
285
286            debug!("Client disconnected");
287        })),
288        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
289    }
290}