rust-oxide-realtime 0.2.0

Reusable realtime transport primitives for Axum servers and Rust websocket clients
Documentation
use std::time::Duration;

use axum::extract::ws::{Message, WebSocket};
use chrono::Utc;
use futures_util::{SinkExt, StreamExt};
use tokio::sync::mpsc;
use tokio::time::{Instant, MissedTickBehavior, interval};

use crate::protocol::{ClientFrame, ServerFrame};

use super::{
    RealtimeConfig, SessionAuth,
    hub::HubCommand,
    types::{ChannelName, ConnectionId, ConnectionMeta, DisconnectReason},
};

/// Drives a single websocket connection until it disconnects.
///
/// The session first registers itself with the hub by sending
/// [`HubCommand::Register`] together with the sending half of a bounded
/// outbound queue (`cfg.outbound_queue_size`). If the hub channel is already
/// closed the function returns immediately without touching the socket.
///
/// Afterwards it multiplexes three event sources:
///
/// * frames queued on the outbound channel, which are serialized to JSON and
///   written to the socket as text messages;
/// * messages read from the socket, which are size-checked, decoded into
///   [`ClientFrame`]s and forwarded via `dispatch_client_frame`;
/// * a heartbeat timer that enforces `cfg.idle_timeout_secs` and sends a
///   websocket ping on every tick.
///
/// When the loop ends, a best-effort [`HubCommand::Unregister`] carrying the
/// [`DisconnectReason`] is sent to the hub.
///
/// # Parameters
///
/// * `socket` - the upgraded websocket to serve.
/// * `auth` - identity of the connected user; its `user_id` and `roles` are
///   copied into the [`ConnectionMeta`] handed to the hub.
/// * `hub_tx` - command channel to the hub.
/// * `cfg` - queue size, heartbeat interval, idle timeout and message size
///   limits for this session.
pub async fn run_socket_session(
    socket: WebSocket,
    auth: SessionAuth,
    hub_tx: mpsc::Sender<HubCommand>,
    cfg: RealtimeConfig,
) {
    let conn_id = ConnectionId::new();
    // Cloned up front because `auth.user_id` is moved into `meta` below but is
    // still needed for the log lines.
    let user_id = auth.user_id.clone();
    let (outbound_tx, mut outbound_rx) = mpsc::channel(cfg.outbound_queue_size);

    let meta = ConnectionMeta {
        id: conn_id,
        user_id: auth.user_id,
        roles: auth.roles,
        joined_at_unix: Utc::now().timestamp(),
    };

    if hub_tx
        .send(HubCommand::Register { meta, outbound_tx })
        .await
        .is_err()
    {
        // The hub receiver is gone; there is nothing to unregister from.
        return;
    }

    tracing::debug!(conn_id = %conn_id, user_id = %user_id, "realtime session connected");

    let (mut ws_sender, mut ws_receiver) = socket.split();

    // `tokio::time::interval` panics when given a zero period. A panic here
    // would abort the session after registration and skip the `Unregister`
    // at the bottom, so a misconfigured `0` is clamped to one second.
    let heartbeat_period = Duration::from_secs(cfg.heartbeat_interval_secs.max(1));
    let mut heartbeat = interval(heartbeat_period);
    heartbeat.set_missed_tick_behavior(MissedTickBehavior::Delay);

    let idle_timeout = Duration::from_secs(cfg.idle_timeout_secs);
    let mut last_activity = Instant::now();

    let disconnect_reason = loop {
        tokio::select! {
            // Hub -> client: forward queued server frames as JSON text.
            outbound = outbound_rx.recv() => {
                let Some(frame) = outbound else {
                    // Every sender for the outbound queue has been dropped.
                    break DisconnectReason::HubUnavailable;
                };

                let Ok(payload) = serde_json::to_string(&frame) else {
                    break DisconnectReason::ProtocolError;
                };

                if ws_sender.send(Message::Text(payload.into())).await.is_err() {
                    break DisconnectReason::SocketError;
                }
            }
            // Client -> hub: validate and dispatch incoming messages.
            incoming = ws_receiver.next() => {
                let Some(incoming) = incoming else {
                    break DisconnectReason::ClientClosed;
                };

                match incoming {
                    Ok(Message::Text(text)) => {
                        last_activity = Instant::now();
                        if text.len() > cfg.max_message_bytes {
                            // Best effort: the oversized frame is dropped but
                            // the session stays open.
                            let _ = send_direct_error(
                                &mut ws_sender,
                                "message_too_large",
                                "Message exceeds realtime.max_message_bytes",
                            ).await;
                            continue;
                        }

                        let frame = match serde_json::from_str::<ClientFrame>(&text) {
                            Ok(frame) => frame,
                            Err(_) => {
                                let _ = send_direct_error(
                                    &mut ws_sender,
                                    "invalid_payload",
                                    "Invalid websocket payload",
                                )
                                .await;
                                continue;
                            }
                        };

                        if dispatch_client_frame(conn_id, frame, &hub_tx, &mut ws_sender)
                            .await
                            .is_err()
                        {
                            break DisconnectReason::HubUnavailable;
                        }
                    }
                    Ok(Message::Binary(_)) => {
                        // Binary frames are rejected and do not refresh
                        // `last_activity`.
                        let _ = send_direct_error(
                            &mut ws_sender,
                            "invalid_payload",
                            "Binary websocket payloads are not supported",
                        )
                        .await;
                    }
                    Ok(Message::Ping(payload)) => {
                        last_activity = Instant::now();
                        if ws_sender.send(Message::Pong(payload)).await.is_err() {
                            break DisconnectReason::SocketError;
                        }
                    }
                    Ok(Message::Pong(_)) => {
                        last_activity = Instant::now();
                    }
                    Ok(Message::Close(_)) => {
                        break DisconnectReason::ClientClosed;
                    }
                    Err(_) => {
                        break DisconnectReason::SocketError;
                    }
                }
            }
            // Heartbeat: enforce the idle timeout, then probe the peer.
            _ = heartbeat.tick() => {
                if last_activity.elapsed() > idle_timeout {
                    break DisconnectReason::IdleTimeout;
                }

                if ws_sender
                    .send(Message::Ping(Vec::new().into()))
                    .await
                    .is_err()
                {
                    break DisconnectReason::SocketError;
                }
            }
        }
    };

    tracing::debug!(
        conn_id = %conn_id,
        user_id = %user_id,
        reason = ?disconnect_reason,
        "realtime session closing"
    );

    // Best effort: if the hub is already gone there is nobody left to notify.
    let _ = hub_tx
        .send(HubCommand::Unregister {
            conn_id,
            reason: disconnect_reason,
        })
        .await;
}

/// Converts a decoded [`ClientFrame`] into a [`HubCommand`] and sends it to
/// the hub.
///
/// Join, leave and emit frames carry a channel name that is validated with
/// `ChannelName::parse`. When validation fails, an `invalid_channel` error is
/// written straight to the socket (best effort), nothing is sent to the hub,
/// and `Ok(())` is returned.
///
/// Returns `Err(())` only when sending the command on `hub_tx` fails.
async fn dispatch_client_frame(
    conn_id: ConnectionId,
    frame: ClientFrame,
    hub_tx: &mpsc::Sender<HubCommand>,
    ws_sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
) -> Result<(), ()> {
    /// Reports a rejected channel name to the client. The outcome of the
    /// write is ignored, so the caller always gets `Ok(())` back.
    async fn reject_invalid_channel(
        sink: &mut futures_util::stream::SplitSink<WebSocket, Message>,
        reason: &str,
    ) -> Result<(), ()> {
        let _ = send_direct_error(sink, "invalid_channel", reason).await;
        Ok(())
    }

    let hub_command = match frame {
        ClientFrame::ChannelJoin { id, channel, .. } => match ChannelName::parse(&channel) {
            Ok(channel) => HubCommand::Join {
                conn_id,
                channel,
                req_id: id,
            },
            Err(err) => {
                let reason = err.message().to_string();
                return reject_invalid_channel(ws_sender, &reason).await;
            }
        },
        ClientFrame::ChannelLeave { id, channel, .. } => match ChannelName::parse(&channel) {
            Ok(channel) => HubCommand::Leave {
                conn_id,
                channel,
                req_id: id,
            },
            Err(err) => {
                let reason = err.message().to_string();
                return reject_invalid_channel(ws_sender, &reason).await;
            }
        },
        ClientFrame::ChannelEmit {
            id,
            channel,
            event,
            data,
            ..
        } => match ChannelName::parse(&channel) {
            Ok(channel) => HubCommand::Emit {
                conn_id,
                channel,
                event,
                payload: data,
                req_id: id,
            },
            Err(err) => {
                let reason = err.message().to_string();
                return reject_invalid_channel(ws_sender, &reason).await;
            }
        },
        ClientFrame::Ping { id, .. } => HubCommand::Ping {
            conn_id,
            req_id: id,
        },
    };

    // A failed send means the hub's receiving side has been dropped.
    match hub_tx.send(hub_command).await {
        Ok(()) => Ok(()),
        Err(_) => Err(()),
    }
}

async fn send_direct_error(
    ws_sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
    code: &str,
    message: &str,
) -> Result<(), ()> {
    let frame = ServerFrame::error(code, message);
    let payload = serde_json::to_string(&frame).map_err(|_| ())?;
    ws_sender
        .send(Message::Text(payload.into()))
        .await
        .map_err(|_| ())
}