Skip to main content

hocuspocus_rs/
lib.rs

1//! Hocuspocus-RS
2//!
3//! A Rust implementation of the Hocuspocus protocol (Yjs over WebSockets).
4//! Provides a handler for Yjs documents that follows the Hocuspocus V2 protocol structure.
5
6pub mod sync;
7#[cfg(feature = "sqlite")]
8pub mod db;
9
10pub use sync::{DocHandler, MSG_AUTH, MSG_AWARENESS, MSG_QUERY_AWARENESS, MSG_SYNC};
11#[cfg(feature = "sqlite")]
12pub use db::Database;
13
14#[cfg(feature = "server")]
15use axum::{
16    extract::{
17        ws::{Message, WebSocket, WebSocketUpgrade},
18        State,
19    },
20    response::Response,
21    routing::get,
22    Router,
23};
24#[cfg(feature = "server")]
25use dashmap::DashMap;
26#[cfg(feature = "server")]
27use futures_util::{
28    stream::{SplitSink, SplitStream},
29    SinkExt, StreamExt,
30};
31#[cfg(feature = "server")]
32use std::sync::Arc;
33
34/// Application state shared across WebSocket connections
35#[cfg(feature = "server")]
36pub struct AppState {
37    pub rooms: DashMap<String, Arc<DocHandler>>,
38    #[cfg(feature = "sqlite")]
39    pub db: Database,
40}
41
42#[cfg(feature = "server")]
43impl AppState {
44    #[cfg(feature = "sqlite")]
45    pub fn new(db: Database) -> Self {
46        Self {
47            rooms: DashMap::new(),
48            db,
49        }
50    }
51
52    #[cfg(not(feature = "sqlite"))]
53    pub fn new() -> Self {
54        Self {
55            rooms: DashMap::new(),
56        }
57    }
58
59    /// Get or create a document handler for a room
60    pub fn get_or_create_handler(&self, room_name: &str) -> Arc<DocHandler> {
61        self.rooms
62            .entry(room_name.to_string())
63            .or_insert_with(|| {
64                let name = room_name.to_string();
65                #[cfg(feature = "sqlite")]
66                {
67                    let db = self.db.clone();
68                    tokio::task::block_in_place(|| {
69                        tokio::runtime::Handle::current()
70                            .block_on(async { Arc::new(DocHandler::new(name, db).await) })
71                    })
72                }
73                #[cfg(not(feature = "sqlite"))]
74                {
75                    tokio::task::block_in_place(|| {
76                        tokio::runtime::Handle::current()
77                            .block_on(async { Arc::new(DocHandler::new(name).await) })
78                    })
79                }
80            })
81            .clone()
82    }
83}
84
85/// Create the sync router (for embedding in other servers)
86#[cfg(feature = "server")]
87pub fn create_router(state: Arc<AppState>) -> Router {
88    Router::new()
89        .route("/sync/:room_name", get(ws_handler))
90        .route("/sync", get(ws_handler_generic))
91        .with_state(state)
92}
93
94// WebSocket handlers
95#[cfg(feature = "server")]
96async fn ws_handler(
97    ws: WebSocketUpgrade,
98    axum::extract::Path(room_name): axum::extract::Path<String>,
99    State(state): State<Arc<AppState>>,
100) -> Response {
101    ws.on_upgrade(move |socket| handle_socket_with_room(socket, state, room_name))
102}
103
104#[cfg(feature = "server")]
105async fn ws_handler_generic(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> Response {
106    ws.on_upgrade(move |socket| handle_socket_generic(socket, state))
107}
108
109#[cfg(feature = "server")]
110async fn handle_socket_with_room(socket: WebSocket, state: Arc<AppState>, room_name: String) {
111    let handler = state.get_or_create_handler(&room_name);
112    let (sender, receiver) = socket.split();
113    run_connection(sender, receiver, handler, room_name, None).await;
114}
115
116#[cfg(feature = "server")]
117async fn handle_socket_generic(socket: WebSocket, state: Arc<AppState>) {
118    let (mut sender, mut receiver) = socket.split();
119
120    // Wait for first message to determine room name
121    let first_msg = match receiver.next().await {
122        Some(Ok(Message::Binary(data))) => data,
123        _ => return,
124    };
125
126    let (_, room_name) = match DocHandler::read_and_skip_doc_name(&first_msg) {
127        Some(res) => res,
128        None => return,
129    };
130
131    let handler = state.get_or_create_handler(&room_name);
132
133    // Process initial message
134    let responses = handler.handle_message(&first_msg).await;
135    for resp in &responses {
136        if sender
137            .send(Message::Binary(resp.clone().into()))
138            .await
139            .is_err()
140        {
141            return;
142        }
143    }
144
145    run_connection(
146        sender,
147        receiver,
148        handler,
149        room_name,
150        Some(first_msg.to_vec()),
151    )
152    .await;
153}
154
155#[cfg(feature = "server")]
156pub async fn run_connection(
157    mut ws_sender: SplitSink<WebSocket, Message>,
158    mut ws_receiver: SplitStream<WebSocket>,
159    handler: Arc<DocHandler>,
160    room_name: String,
161    _initial_message: Option<Vec<u8>>,
162) {
163    // Send initial sync
164    let initial_msgs = handler.generate_initial_sync();
165    for msg in initial_msgs {
166        if ws_sender.send(Message::Binary(msg.into())).await.is_err() {
167            return;
168        }
169    }
170
171    let mut broadcast_rx = handler.subscribe();
172
173    loop {
174        tokio::select! {
175            msg = ws_receiver.next() => {
176                match msg {
177                    Some(Ok(Message::Binary(data))) => {
178                        let responses = handler.handle_message(&data).await;
179                        for resp in responses {
180                            if ws_sender.send(Message::Binary(resp.into())).await.is_err() {
181                                return;
182                            }
183                        }
184                    }
185                    Some(Ok(Message::Ping(data))) => {
186                        let _ = ws_sender.send(Message::Pong(data)).await;
187                    }
188                    Some(Ok(Message::Close(_))) | None => {
189                        tracing::debug!("Client disconnected from room '{}'", room_name);
190                        return;
191                    }
192                    _ => {}
193                }
194            }
195            msg = broadcast_rx.recv() => {
196                if let Ok(data) = msg {
197                    if ws_sender.send(Message::Binary(data.into())).await.is_err() {
198                        return;
199                    }
200                }
201            }
202        }
203    }
204}