1use crate::{
7 authenticator::Authenticator, client_connection::{ClientConnection, DocConnectionConfig, DocServer}, doc_sync::DocWithSyncKv, store::{memory::MemoryStore, Store}, sync::awareness::Awareness, sync_kv::SyncKv, types::HocuspocusConfiguration
8};
9use anyhow::{Result, anyhow};
10use async_trait::async_trait;
11use axum::{
12 Router,
13 extract::{
14 State,
15 ws::{Message as WsMessage, WebSocketUpgrade},
16 },
17 response::IntoResponse,
18 routing::get,
19};
20use dashmap::{DashMap, mapref::one::MappedRef};
21use futures::SinkExt;
22use futures_util::StreamExt;
23use std::sync::Arc;
24use std::{sync::RwLock, time::Duration};
25use tokio::sync::mpsc::{self, channel};
26use tokio::{net::TcpListener, sync::mpsc::Receiver};
27use tokio_util::{sync::CancellationToken, task::TaskTracker};
28use tracing::{Instrument, Level, info, span};
29
30pub type HocuspocusServer = Arc<Server>;
31
32pub struct Server {
33 docs: Arc<DashMap<String, DocWithSyncKv>>,
34 doc_worker_tracker: TaskTracker,
35 store: Arc<dyn Store>,
36 checkpoint_freq: Duration,
37 authenticator: Option<Arc<dyn Authenticator>>,
38 cancellation_token: CancellationToken,
39 doc_gc: bool,
40 port: u16,
41}
42
43impl Server {
44 pub fn new(
45 store: Arc<dyn Store>,
46 checkpoint_freq: Duration,
47 authenticator: Option<Arc<dyn Authenticator>>,
48 cancellation_token: CancellationToken,
49 doc_gc: bool,
50 port: u16,
51 ) -> Self {
52 Self {
53 docs: Arc::new(DashMap::new()),
54 doc_worker_tracker: TaskTracker::new(),
55 store,
56 checkpoint_freq,
57 authenticator,
58 cancellation_token,
59 doc_gc,
60 port,
61 }
62 }
63
64 pub async fn start(port: u16) -> anyhow::Result<()> {
65 let server = Arc::new(Server {
66 docs: Arc::new(DashMap::new()),
67 doc_worker_tracker: TaskTracker::new(),
68 store: Arc::new(MemoryStore::default()), checkpoint_freq: Duration::from_secs(60), authenticator: None,
71 cancellation_token: CancellationToken::new(),
72 doc_gc: true,
73 port,
74 });
75 let app = Router::new()
76 .route("/", get(ws_handler))
77 .with_state(server.clone());
78
79 let addr = format!("0.0.0.0:{}", server.port);
80 let listener = TcpListener::bind(&addr).await?;
81
82 tracing::info!("Hocuspocus server listening on {}", addr);
83
84 axum::serve(listener, app).await?;
85
86 Ok(())
87 }
88
89 async fn load_doc(&self, doc_id: &str) -> Result<()> {
90 let (send, recv) = channel(1024);
91
92 let dwskv = DocWithSyncKv::new(doc_id, self.store.clone(), move || {
93 send.try_send(()).unwrap();
94 })
95 .await?;
96
97 dwskv
98 .sync_kv()
99 .persist()
100 .await
101 .map_err(|e| anyhow!("Error persisting: {:?}", e))?;
102
103 {
104 let sync_kv = dwskv.sync_kv();
105 let checkpoint_freq = self.checkpoint_freq;
106 let doc_id = doc_id.to_string();
107 let cancellation_token = self.cancellation_token.clone();
108
109 self.doc_worker_tracker.spawn(
111 Self::doc_persistence_worker(
112 recv,
113 sync_kv,
114 checkpoint_freq,
115 doc_id.clone(),
116 cancellation_token.clone(),
117 )
118 .instrument(span!(Level::INFO, "save_loop", doc_id=?doc_id)),
119 );
120
121 if self.doc_gc {
122 self.doc_worker_tracker.spawn(
123 Self::doc_gc_worker(
124 self.docs.clone(),
125 doc_id.clone(),
126 checkpoint_freq,
127 cancellation_token,
128 )
129 .instrument(span!(Level::INFO, "gc_loop", doc_id=?doc_id)),
130 );
131 }
132 }
133
134 self.docs.insert(doc_id.to_string(), dwskv);
135 Ok(())
136 }
137
138 async fn doc_gc_worker(
139 docs: Arc<DashMap<String, DocWithSyncKv>>,
140 doc_id: String,
141 checkpoint_freq: Duration,
142 cancellation_token: CancellationToken,
143 ) {
144 let mut checkpoints_without_refs = 0;
145
146 loop {
147 tokio::select! {
148 _ = tokio::time::sleep(checkpoint_freq) => {
149 if let Some(doc) = docs.get(&doc_id) {
150 let awareness = Arc::downgrade(&doc.awareness());
151 if awareness.strong_count() > 1 {
152 checkpoints_without_refs = 0;
153 tracing::debug!("doc is still alive - it has {} references", awareness.strong_count());
154 } else {
155 checkpoints_without_refs += 1;
156 tracing::info!("doc has only one reference, candidate for GC. checkpoints_without_refs: {}", checkpoints_without_refs);
157 }
158 } else {
159 break;
160 }
161
162 if checkpoints_without_refs >= 2 {
163 tracing::info!("GCing doc");
164 if let Some(doc) = docs.get(&doc_id) {
165 doc.sync_kv().shutdown();
166 }
167
168 docs.remove(&doc_id);
169 break;
170 }
171 }
172 _ = cancellation_token.cancelled() => {
173 break;
174 }
175 };
176 }
177 tracing::info!("Exiting gc_loop");
178 }
179
180 async fn doc_persistence_worker(
181 mut recv: Receiver<()>,
182 sync_kv: Arc<SyncKv>,
183 checkpoint_freq: Duration,
184 doc_id: String,
185 cancellation_token: CancellationToken,
186 ) {
187 let mut last_save = std::time::Instant::now();
188
189 loop {
190 let is_done = tokio::select! {
191 v = recv.recv() => v.is_none(),
192 _ = cancellation_token.cancelled() => true,
193 _ = tokio::time::sleep(checkpoint_freq) => {
194 sync_kv.is_shutdown()
195 }
196 };
197
198 tracing::info!("Received signal. done: {}", is_done);
199 let now = std::time::Instant::now();
200 if !is_done && now - last_save < checkpoint_freq {
201 let sleep = tokio::time::sleep(checkpoint_freq - (now - last_save));
202 tokio::pin!(sleep);
203 tracing::info!("Throttling.");
204
205 loop {
206 tokio::select! {
207 _ = &mut sleep => {
208 break;
209 }
210 v = recv.recv() => {
211 tracing::info!("Received dirty while throttling.");
212 if v.is_none() {
213 break;
214 }
215 }
216 _ = cancellation_token.cancelled() => {
217 tracing::info!("Received cancellation while throttling.");
218 break;
219 }
220
221 }
222 tracing::info!("Done throttling.");
223 }
224 }
225 tracing::info!("Persisting.");
226 if let Err(e) = sync_kv.persist().await {
227 tracing::error!(?e, "Error persisting.");
228 } else {
229 tracing::info!("Done persisting.");
230 }
231 last_save = std::time::Instant::now();
232
233 if is_done {
234 break;
235 }
236 }
237 tracing::info!("Terminating loop for {}", doc_id);
238 }
239
240 pub async fn get_or_create_doc(
241 &self,
242 doc_id: &str,
243 ) -> Result<MappedRef<String, DocWithSyncKv, DocWithSyncKv>> {
244 if !self.docs.contains_key(doc_id) {
245 tracing::info!(doc_id=?doc_id, "Loading doc");
246 self.load_doc(doc_id).await?;
247 }
248
249 Ok(self
250 .docs
251 .get(doc_id)
252 .ok_or_else(|| anyhow!("Failed to get-or-create doc"))?
253 .map(|d| d))
254 }
255
256}
257
258#[async_trait]
259impl DocServer for Server {
260 async fn fetch(&self, doc_id: &str) -> Result<Arc<RwLock<Awareness>>> {
261 Ok(self.get_or_create_doc(doc_id).await?.awareness())
262 }
263
264 async fn authenticate(&self, doc_id: &str, token: &str) -> Result<DocConnectionConfig> {
265 if let Some(auth) = &self.authenticator {
266 Ok(auth.authenticate(doc_id, token).await?)
267 } else {
268 Ok(DocConnectionConfig::default())
269 }
270 }
271}
272
273async fn ws_handler(
274 ws: WebSocketUpgrade,
275 State(hocuspocus): State<Arc<Server>>,
276 _request: axum::http::Request<axum::body::Body>,
277) -> impl IntoResponse {
278 ws.on_upgrade(move |socket| handle_websocket_upgrade(socket, hocuspocus))
290}
291
292async fn handle_websocket_upgrade(
293 socket: axum::extract::ws::WebSocket,
294 hocuspocus: Arc<Server>,
295) {
296 tracing::debug!("handle_websocket_upgrade : {:?}", socket);
297
298 let (_close_tx, _close_rx) = mpsc::channel::<()>(1);
299 let (mut sink, stream) = socket.split();
300
301 let hocuspocus_clone = hocuspocus.clone();
302 let (tx_to_ws, mut rx_to_ws) = mpsc::channel(16);
303
304 let client_connection = ClientConnection::new(
305 hocuspocus_clone,
306 tx_to_ws.clone(),
307 Duration::from_secs(300),
308 Default::default(),
309 );
310
311 tokio::spawn(async move {
312 loop {
313 match rx_to_ws.recv().await {
314 Some(msg) => {
315 let _ = sink.send(WsMessage::Binary(msg.into())).await;
316 }
317 None => {
318 info!("client connection already closed");
319 return;
320 }
321 }
322 }
323 });
324
325 let mut stream = stream;
326 loop {
327 match stream.next().await {
328 Some(Ok(WsMessage::Binary(data))) => {
329 tracing::debug!("Received buffer: {:?}", data);
330 let result = client_connection.handle_message(&data).await;
331
332 if let Err(e) = result {
333 tracing::warn!("Failed to handle message: {}", e);
334 }
335 }
342 Some(Ok(WsMessage::Close(_))) => {
343 drop(stream);
344 break;
345 }
346 Some(Err(e)) => {
347 drop(stream);
348 tracing::error!("WebSocket error: {}", e);
349 break;
350 }
351 None => {
352 drop(stream);
353 break;
354 }
355 _ => {
356 }
358 }
359 }
360}
361
362pub async fn start_server(
363 _configuration: HocuspocusConfiguration,
364 port: u16,
365) -> anyhow::Result<()> {
366 Server::start(port).await
403}