hocuspocus_rs_ws/
hocuspocus.rs

// Portions of this module are adapted from the Hocuspocus JavaScript server
// (https://github.com/ueberdosis/hocuspocus) and y-sweet
// (https://github.com/y-sweet/y-sweet), both distributed under the MIT license.
// Adapted code retains the original license terms.

6use crate::{
7    authenticator::Authenticator, client_connection::{ClientConnection, DocConnectionConfig, DocServer}, doc_sync::DocWithSyncKv, store::{memory::MemoryStore, Store}, sync::awareness::Awareness, sync_kv::SyncKv, types::HocuspocusConfiguration
8};
9use anyhow::{Result, anyhow};
10use async_trait::async_trait;
11use axum::{
12    Router,
13    extract::{
14        State,
15        ws::{Message as WsMessage, WebSocketUpgrade},
16    },
17    response::IntoResponse,
18    routing::get,
19};
20use dashmap::{DashMap, mapref::one::MappedRef};
21use futures::SinkExt;
22use futures_util::StreamExt;
23use std::sync::Arc;
24use std::{sync::RwLock, time::Duration};
25use tokio::sync::mpsc::{self, channel};
26use tokio::{net::TcpListener, sync::mpsc::Receiver};
27use tokio_util::{sync::CancellationToken, task::TaskTracker};
28use tracing::{Instrument, Level, info, span};
29
30pub type HocuspocusServer = Arc<Server>;
31
32pub struct Server {
33    docs: Arc<DashMap<String, DocWithSyncKv>>,
34    doc_worker_tracker: TaskTracker,
35    store: Arc<dyn Store>,
36    checkpoint_freq: Duration,
37    authenticator: Option<Arc<dyn Authenticator>>,
38    cancellation_token: CancellationToken,
39    doc_gc: bool,
40    port: u16,
41}
42
43impl Server {
44    pub fn new(
45        store: Arc<dyn Store>,
46        checkpoint_freq: Duration,
47        authenticator: Option<Arc<dyn Authenticator>>,
48        cancellation_token: CancellationToken,
49        doc_gc: bool,
50        port: u16,
51    ) -> Self {
52        Self {
53            docs: Arc::new(DashMap::new()),
54            doc_worker_tracker: TaskTracker::new(),
55            store,
56            checkpoint_freq,
57            authenticator,
58            cancellation_token,
59            doc_gc,
60            port,
61        }
62    }
63
64    pub async fn start(port: u16) -> anyhow::Result<()> {
65        let server = Arc::new(Server {
66            docs: Arc::new(DashMap::new()),
67            doc_worker_tracker: TaskTracker::new(),
68            store: Arc::new(MemoryStore::default()), // Some(Arc::new(Box::new(MemoryStore::default()))) ,
69            checkpoint_freq: Duration::from_secs(60), // 1분마다 GC 및 체크포인트
70            authenticator: None,
71            cancellation_token: CancellationToken::new(),
72            doc_gc: true,
73            port,
74        });
75        let app = Router::new()
76            .route("/", get(ws_handler))
77            .with_state(server.clone());
78
79        let addr = format!("0.0.0.0:{}", server.port);
80        let listener = TcpListener::bind(&addr).await?;
81
82        tracing::info!("Hocuspocus server listening on {}", addr);
83
84        axum::serve(listener, app).await?;
85
86        Ok(())
87    }
88
89    async fn load_doc(&self, doc_id: &str) -> Result<()> {
90        let (send, recv) = channel(1024);
91
92        let dwskv = DocWithSyncKv::new(doc_id, self.store.clone(), move || {
93            send.try_send(()).unwrap();
94        })
95        .await?;
96
97        dwskv
98            .sync_kv()
99            .persist()
100            .await
101            .map_err(|e| anyhow!("Error persisting: {:?}", e))?;
102
103        {
104            let sync_kv = dwskv.sync_kv();
105            let checkpoint_freq = self.checkpoint_freq;
106            let doc_id = doc_id.to_string();
107            let cancellation_token = self.cancellation_token.clone();
108
109            // Spawn a task to save the document to the store when it changes.
110            self.doc_worker_tracker.spawn(
111                Self::doc_persistence_worker(
112                    recv,
113                    sync_kv,
114                    checkpoint_freq,
115                    doc_id.clone(),
116                    cancellation_token.clone(),
117                )
118                .instrument(span!(Level::INFO, "save_loop", doc_id=?doc_id)),
119            );
120
121            if self.doc_gc {
122                self.doc_worker_tracker.spawn(
123                    Self::doc_gc_worker(
124                        self.docs.clone(),
125                        doc_id.clone(),
126                        checkpoint_freq,
127                        cancellation_token,
128                    )
129                    .instrument(span!(Level::INFO, "gc_loop", doc_id=?doc_id)),
130                );
131            }
132        }
133
134        self.docs.insert(doc_id.to_string(), dwskv);
135        Ok(())
136    }
137
138    async fn doc_gc_worker(
139        docs: Arc<DashMap<String, DocWithSyncKv>>,
140        doc_id: String,
141        checkpoint_freq: Duration,
142        cancellation_token: CancellationToken,
143    ) {
144        let mut checkpoints_without_refs = 0;
145
146        loop {
147            tokio::select! {
148                _ = tokio::time::sleep(checkpoint_freq) => {
149                    if let Some(doc) = docs.get(&doc_id) {
150                        let awareness = Arc::downgrade(&doc.awareness());
151                        if awareness.strong_count() > 1 {
152                            checkpoints_without_refs = 0;
153                            tracing::debug!("doc is still alive - it has {} references", awareness.strong_count());
154                        } else {
155                            checkpoints_without_refs += 1;
156                            tracing::info!("doc has only one reference, candidate for GC. checkpoints_without_refs: {}", checkpoints_without_refs);
157                        }
158                    } else {
159                        break;
160                    }
161
162                    if checkpoints_without_refs >= 2 {
163                        tracing::info!("GCing doc");
164                        if let Some(doc) = docs.get(&doc_id) {
165                            doc.sync_kv().shutdown();
166                        }
167
168                        docs.remove(&doc_id);
169                        break;
170                    }
171                }
172                _ = cancellation_token.cancelled() => {
173                    break;
174                }
175            };
176        }
177        tracing::info!("Exiting gc_loop");
178    }
179
180    async fn doc_persistence_worker(
181        mut recv: Receiver<()>,
182        sync_kv: Arc<SyncKv>,
183        checkpoint_freq: Duration,
184        doc_id: String,
185        cancellation_token: CancellationToken,
186    ) {
187        let mut last_save = std::time::Instant::now();
188
189        loop {
190            let is_done = tokio::select! {
191                v = recv.recv() => v.is_none(),
192                _ = cancellation_token.cancelled() => true,
193                _ = tokio::time::sleep(checkpoint_freq) => {
194                    sync_kv.is_shutdown()
195                }
196            };
197
198            tracing::info!("Received signal. done: {}", is_done);
199            let now = std::time::Instant::now();
200            if !is_done && now - last_save < checkpoint_freq {
201                let sleep = tokio::time::sleep(checkpoint_freq - (now - last_save));
202                tokio::pin!(sleep);
203                tracing::info!("Throttling.");
204
205                loop {
206                    tokio::select! {
207                        _ = &mut sleep => {
208                            break;
209                        }
210                        v = recv.recv() => {
211                            tracing::info!("Received dirty while throttling.");
212                            if v.is_none() {
213                                break;
214                            }
215                        }
216                        _ = cancellation_token.cancelled() => {
217                            tracing::info!("Received cancellation while throttling.");
218                            break;
219                        }
220
221                    }
222                    tracing::info!("Done throttling.");
223                }
224            }
225            tracing::info!("Persisting.");
226            if let Err(e) = sync_kv.persist().await {
227                tracing::error!(?e, "Error persisting.");
228            } else {
229                tracing::info!("Done persisting.");
230            }
231            last_save = std::time::Instant::now();
232
233            if is_done {
234                break;
235            }
236        }
237        tracing::info!("Terminating loop for {}", doc_id);
238    }
239
240    pub async fn get_or_create_doc(
241        &self,
242        doc_id: &str,
243    ) -> Result<MappedRef<String, DocWithSyncKv, DocWithSyncKv>> {
244        if !self.docs.contains_key(doc_id) {
245            tracing::info!(doc_id=?doc_id, "Loading doc");
246            self.load_doc(doc_id).await?;
247        }
248
249        Ok(self
250            .docs
251            .get(doc_id)
252            .ok_or_else(|| anyhow!("Failed to get-or-create doc"))?
253            .map(|d| d))
254    }
255
256}
257
258#[async_trait]
259impl DocServer for Server {
260    async fn fetch(&self, doc_id: &str) -> Result<Arc<RwLock<Awareness>>> {
261        Ok(self.get_or_create_doc(doc_id).await?.awareness())
262    }
263
264    async fn authenticate(&self, doc_id: &str, token: &str) -> Result<DocConnectionConfig> {
265        if let Some(auth) = &self.authenticator {
266            Ok(auth.authenticate(doc_id, token).await?)
267        } else {
268            Ok(DocConnectionConfig::default())
269        }
270    }
271}
272
273async fn ws_handler(
274    ws: WebSocketUpgrade,
275    State(hocuspocus): State<Arc<Server>>,
276    _request: axum::http::Request<axum::body::Body>,
277) -> impl IntoResponse {
278    // let document_name = document.unwrap_or_else(|| "default".to_string());
279    // let document_name = "".to_string(); // Default document name
280
281    // Trigger onUpgrade hooks
282    // for extension in &hocuspocus.configuration.extensions {
283    //     if let Err(e) = extension.on_upgrade(&request).await {
284    //         tracing::error!("onUpgrade hook failed: {}", e);
285    //         return StatusCode::INTERNAL_SERVER_ERROR.into_response();
286    //     }
287    // }
288
289    ws.on_upgrade(move |socket| handle_websocket_upgrade(socket, hocuspocus))
290}
291
292async fn handle_websocket_upgrade(
293    socket: axum::extract::ws::WebSocket,
294    hocuspocus: Arc<Server>,
295) {
296    tracing::debug!("handle_websocket_upgrade : {:?}", socket);
297
298    let (_close_tx, _close_rx) = mpsc::channel::<()>(1);
299    let (mut sink, stream) = socket.split();
300
301    let hocuspocus_clone = hocuspocus.clone();
302    let (tx_to_ws, mut rx_to_ws) = mpsc::channel(16);
303
304    let client_connection = ClientConnection::new(
305        hocuspocus_clone,
306        tx_to_ws.clone(),
307        Duration::from_secs(300),
308        Default::default(),
309    );
310
311    tokio::spawn(async move {
312        loop {
313            match rx_to_ws.recv().await {
314                Some(msg) => {
315                    let _ = sink.send(WsMessage::Binary(msg.into())).await;
316                }
317                None => {
318                    info!("client connection already closed");
319                    return;
320                }
321            }
322        }
323    });
324
325    let mut stream = stream;
326    loop {
327        match stream.next().await {
328            Some(Ok(WsMessage::Binary(data))) => {
329                tracing::debug!("Received buffer: {:?}", data);
330                let result = client_connection.handle_message(&data).await;
331
332                if let Err(e) = result {
333                    tracing::warn!("Failed to handle message: {}", e);
334                }
335                // if let Err(e) = message_receiver
336                //     .handle_ws_bytes(&document.clone(), data)
337                //     .await
338                // {
339                //     tracing::error!("Failed to handle message: {}", e);
340                // }
341            }
342            Some(Ok(WsMessage::Close(_))) => {
343                drop(stream);
344                break;
345            }
346            Some(Err(e)) => {
347                drop(stream);
348                tracing::error!("WebSocket error: {}", e);
349                break;
350            }
351            None => {
352                drop(stream);
353                break;
354            }
355            _ => {
356                // Ignore other message types
357            }
358        }
359    }
360}
361
362pub async fn start_server(
363    _configuration: HocuspocusConfiguration,
364    port: u16,
365) -> anyhow::Result<()> {
366    // let hocuspocus = Hocuspocus::new(configuration).await?;
367    // let store = if let Some(store) = store {
368    //     let store = get_store_from_opts(store)?;
369    //     store.init().await?;
370    //     Some(store)
371    // } else {
372    //     tracing::warn!("No store set. Documents will be stored in memory only.");
373    //     None
374    // };
375
376    // if !prod {
377    //     print_server_url(auth.as_ref(), url_prefix.as_ref(), addr);
378    // }
379
380    // let token = CancellationToken::new();
381
382    // let auth = if let Some(auth) = Some("<auth_key>") {
383    //     Some(Authenticator::new(auth)?)
384    // } else {
385    //     tracing::warn!("No auth key set. Only use this for local development!");
386    //     None
387    // };
388
389    // let server = Server::new(
390    //     None,
391    //     std::time::Duration::from_secs(10),
392    //     None,
393    //     // url_prefix.clone(),
394    //     None,
395    //     token.clone(),
396    //     true,
397    //     // *max_body_size,
398    //     None,
399    //     port,
400    // );
401
402    Server::start(port).await
403}