1use std::sync::Arc;
2use std::time::{Duration, SystemTime, UNIX_EPOCH};
3
4use anyhow::Result;
5use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
6use axum::extract::State;
7use axum::response::IntoResponse;
8
9use dashmap::DashMap;
10use futures_util::StreamExt;
11use hocuspocus_extension_database::types::{FetchContext, StoreContext};
12use hocuspocus_extension_database::DatabaseExtension;
13use tokio::sync::mpsc as tokio_mpsc;
14use tokio::time::{sleep_until, Instant};
15use yrs::encoding::read::{Cursor as YCursor, Read as YRead};
16use yrs::encoding::write::Write as YWrite;
17use yrs::sync::{Awareness, DefaultProtocol, Message as YMsg, Protocol};
18use yrs::updates::decoder::Decode;
19use yrs::updates::encoder::Encode;
20use yrs::{Doc, ReadTxn, StateVector, Transact, Update};
21
22#[cfg(feature = "redis")]
23use hocuspocus_extension_redis::RedisBroadcaster;
24
25pub struct DocRegistry {
26 doc_counts: DashMap<String, usize>,
27 doc_latest: DashMap<String, Vec<u8>>, }
29
30impl DocRegistry {
31 pub fn new() -> Self {
32 Self {
33 doc_counts: DashMap::new(),
34 doc_latest: DashMap::new(),
35 }
36 }
37
38 pub fn increment(&self, name: &str) -> usize {
39 let entry = self
40 .doc_counts
41 .entry(name.to_string())
42 .and_modify(|c| *c += 1)
43 .or_insert(1);
44 *entry
45 }
46
47 pub fn decrement(&self, name: &str) -> Option<usize> {
48 if let Some(mut entry) = self.doc_counts.get_mut(name) {
49 if *entry > 0 {
50 *entry -= 1;
51 }
52 let remaining = *entry;
53 drop(entry);
54 Some(remaining)
55 } else {
56 None
57 }
58 }
59
60 pub fn set_latest(&self, name: &str, bytes: Vec<u8>) {
61 self.doc_latest.insert(name.to_string(), bytes);
62 }
63
64 pub fn get_latest_cloned(&self, name: &str) -> Option<Vec<u8>> {
65 self.doc_latest.get(name).map(|v| v.clone())
66 }
67
68 pub fn remove(&self, name: &str) {
69 self.doc_counts.remove(name);
70 self.doc_latest.remove(name);
71 }
72}
73
74pub struct AppState<E: DatabaseExtension> {
75 pub db: Arc<E>,
76 pub debounce_ms: u64,
77 pub max_debounce_ms: u64,
78 pub doc_registry: DocRegistry,
79 #[cfg(feature = "redis")]
80 pub redis: Option<Arc<RedisBroadcaster>>, pub auth: Option<Arc<dyn AuthProvider + Send + Sync>>,
83}
84
85pub async fn ws_handler<E: DatabaseExtension + 'static>(
86 ws: WebSocketUpgrade,
87 State(state): State<Arc<AppState<E>>>,
88) -> impl IntoResponse {
89 tracing::debug!("upgrade request");
90 ws.on_upgrade(move |socket| on_ws::<E>(socket, state))
91}
92
// Message-type tags written right after the document name in every frame.

/// Document sync payload (sub-types: 0 = step 1, 1 = step 2, 2 = update).
const MSG_SYNC: u32 = 0;
/// Awareness (presence) update.
const MSG_AWARENESS: u32 = 1;
/// Authentication handshake (sub-types: 0 = token, 1 = denied, 2 = authenticated).
const MSG_AUTH: u32 = 2;
/// Request for the current awareness state.
const MSG_QUERY_AWARENESS: u32 = 3;
/// Acknowledgement sent back after a sync update has been handled.
const MSG_SYNC_STATUS: u32 = 8;
98
/// Commands a connection task sends to its document worker thread.
enum WorkerCmd {
    /// Merge a stored v1-encoded state update into the worker's document.
    ApplyState(Vec<u8>),
    /// A raw inbound frame: `[document name][message type][payload]`.
    InboundWs(Vec<u8>),
    /// Switch read-only mode; while on, updates are acknowledged but not applied.
    SetReadonly(bool),
    /// Terminate the worker loop.
    Stop,
}
105
/// Events a document worker thread sends back to its connection task.
enum WorkerEvent {
    /// A complete frame to forward to the client unchanged.
    OutgoingWs(Vec<u8>),
    /// A v1-encoded full-document snapshot that should be persisted.
    StoreState(Vec<u8>),
}
110
111fn worker_thread(
112 cmd_rx: std::sync::mpsc::Receiver<WorkerCmd>,
113 ev_tx: tokio_mpsc::Sender<WorkerEvent>,
114) {
115 let doc = Doc::new();
129 let mut awareness = Awareness::new(doc.clone());
130 let protocol = DefaultProtocol;
131 let mut is_readonly = false;
132
133 let ev_tx_updates = ev_tx.clone();
135 let doc_for_obs = doc.clone();
136 let doc_for_txn = doc.clone();
137 doc_for_obs
138 .observe_update_v1(move |_txn, _u| {
139 let bytes = {
140 let txn = doc_for_txn.transact();
141 let sv = StateVector::default();
142 txn.encode_state_as_update_v1(&sv)
143 };
144 let _ = ev_tx_updates.try_send(WorkerEvent::StoreState(bytes));
145 })
146 .ok();
147
148 while let Ok(cmd) = cmd_rx.recv() {
149 match cmd {
150 WorkerCmd::ApplyState(bytes) => {
151 if let Ok(update) = Update::decode_v1(&bytes) {
152 if let Err(e) = doc.transact_mut().apply_update(update) {
153 tracing::warn!(?e, "failed to apply initial state");
154 }
155 }
156 }
157 WorkerCmd::InboundWs(data) => {
158 if data.is_empty() {
159 continue;
160 }
161 let mut cur = YCursor::new(&data);
163 let frame_doc_name = cur.read_string().unwrap_or("").to_string();
164 let t: u32 = cur.read_var().unwrap_or(0);
165 tracing::debug!(doc_name = %frame_doc_name, msg_type = t, len = data.len(), "inbound frame");
166 let body = &data[cur.next..];
167 match t {
168 MSG_SYNC => {
169 if body.is_empty() {
171 tracing::debug!("empty y-sync body");
172 } else {
173 let mut bcur = YCursor::new(body);
174 let subtag: u32 = bcur.read_var().unwrap_or(u32::MAX);
175 match subtag {
176 0 => {
177 if let Ok(sv_bytes) = bcur.read_buf() {
179 if let Ok(sv) = StateVector::decode_v1(sv_bytes) {
180 let update = {
182 let txn = doc.transact();
183 txn.encode_state_as_update_v1(&sv)
184 };
185 let mut out = Vec::new();
186 out.write_string(&frame_doc_name);
187 out.write_var(MSG_SYNC);
188 out.write_var(1u32);
190 out.write_buf(&update);
192 tracing::debug!(len = out.len(), out_doc = %frame_doc_name, manual = true, "outbound sync step2 reply");
193 let _ =
194 ev_tx.blocking_send(WorkerEvent::OutgoingWs(out));
195 }
196 }
197 }
198 1 | 2 => {
199 if let Ok(upd_bytes) = bcur.read_buf() {
201 if is_readonly {
202 let mut ack = Vec::new();
204 ack.write_string(&frame_doc_name);
205 ack.write_var(MSG_SYNC_STATUS);
206 ack.write_var(1u32);
207 tracing::debug!(out_doc = %frame_doc_name, manual = true, "ack in readonly mode (no apply)");
208 let _ =
209 ev_tx.blocking_send(WorkerEvent::OutgoingWs(ack));
210 } else {
211 if let Ok(update) = Update::decode_v1(upd_bytes) {
213 if let Err(e) =
214 doc.transact_mut().apply_update(update)
215 {
216 tracing::warn!(?e, "failed to apply update");
217 }
218 let mut ack = Vec::new();
220 ack.write_string(&frame_doc_name);
221 ack.write_var(MSG_SYNC_STATUS);
222 ack.write_var(1u32);
223 tracing::debug!(out_doc = %frame_doc_name, manual = true, "outbound sync status ack");
224 let _ = ev_tx
225 .blocking_send(WorkerEvent::OutgoingWs(ack));
226
227 let bytes = {
229 let txn = doc.transact();
230 let sv = StateVector::default();
231 txn.encode_state_as_update_v1(&sv)
232 };
233 let _ = ev_tx
234 .blocking_send(WorkerEvent::StoreState(bytes));
235 } else {
236 tracing::debug!(
237 "failed to decode update bytes (v1 only)"
238 );
239 }
240 }
241 }
242 }
243 _ => {
244 tracing::debug!(subtag, "unknown y-sync submessage");
245 }
246 }
247 }
248 }
249 2 => {
250 tracing::debug!("auth message received by worker; ignoring");
252 }
253 MSG_AWARENESS => {
254 if let Ok(inner) = YMsg::decode_v1(body) {
255 if let YMsg::Awareness(update) = inner {
256 let reply = protocol
257 .handle_awareness_update(&mut awareness, update)
258 .ok()
259 .flatten();
260 if let Some(msg) = reply {
261 let mut out = Vec::new();
262 out.write_string(&frame_doc_name);
263 out.write_var(MSG_AWARENESS);
264 out.extend(msg.encode_v1());
265 tracing::debug!(len = out.len(), out_doc = %frame_doc_name, "outbound awareness echo");
266 let _ = ev_tx.blocking_send(WorkerEvent::OutgoingWs(out));
267 }
268 }
269 } else {
270 tracing::debug!("failed to decode awareness message");
271 }
272 }
273 MSG_QUERY_AWARENESS => {
274 if let Ok(Some(reply)) = protocol.handle_awareness_query(&awareness) {
275 let mut out = Vec::new();
276 out.write_string(&frame_doc_name);
277 out.write_var(MSG_AWARENESS);
278 out.extend(reply.encode_v1());
279 tracing::debug!(len = out.len(), out_doc = %frame_doc_name, "outbound awareness reply");
280 let _ = ev_tx.blocking_send(WorkerEvent::OutgoingWs(out));
281 }
282 }
283 _ => {}
284 }
285 }
286 WorkerCmd::SetReadonly(ro) => {
287 is_readonly = ro;
288 }
289 WorkerCmd::Stop => break,
290 }
291 }
292}
293
294async fn on_ws<E: DatabaseExtension + 'static>(mut socket: WebSocket, state: Arc<AppState<E>>) {
295 tracing::debug!("connection established");
296
297 let (cmd_tx, cmd_rx) = std::sync::mpsc::channel::<WorkerCmd>();
299 let (ev_tx, mut ev_rx) = tokio_mpsc::channel::<WorkerEvent>(64);
300
301 std::thread::spawn(move || worker_thread(cmd_rx, ev_tx));
303
304 let mut pending = false;
306 let mut first_pending_at: Option<Instant> = None;
307 let mut next_deadline: Option<Instant> = None;
308 let mut latest_state_bytes: Option<Vec<u8>> = None;
309 let mut selected_doc_name: Option<String> = None;
310 let mut is_authenticated: bool = state.auth.is_none();
311 let mut loaded_state: bool = false;
312 #[cfg(feature = "redis")]
313 let mut redis_sub_handle: Option<tokio::task::JoinHandle<()>> = None;
314
315 loop {
316 let sleep_fut = if let Some(deadline) = next_deadline {
317 Some(Box::pin(sleep_until(deadline)))
318 } else {
319 None
320 };
321
322 tokio::select! {
323 maybe_msg = socket.next() => {
325 match maybe_msg {
326 Some(Ok(Message::Binary(b))) => {
327 let mut handled_by_auth = false;
329 if selected_doc_name.is_none() {
330 let mut cur = YCursor::new(b.as_ref());
331 if let Ok(name_str) = cur.read_string() {
332 let name = name_str.to_string();
333 tracing::debug!(document_name = %name, "first frame received");
334 let _ = state.doc_registry.increment(&name);
336 selected_doc_name = Some(name.clone());
337
338 if state.auth.is_some() && !is_authenticated {
340 let mtype: u32 = cur.read_var().unwrap_or(u32::MAX);
342 if mtype == MSG_AUTH {
343 let sub: u32 = cur.read_var().unwrap_or(u32::MAX);
345 if sub == 0 {
346 if let Ok(token_str) = cur.read_string() {
348 if let Some(provider) = state.auth.as_ref() {
349 match provider.on_authenticate(&name, &token_str) {
350 Ok(scope) => {
351 let readonly = matches!(scope, AuthScope::ReadOnly);
353 send_auth_authenticated(&mut socket, &name, readonly).await;
354 is_authenticated = true;
355 let _ = cmd_tx.send(WorkerCmd::SetReadonly(readonly));
356 handled_by_auth = true; }
358 Err(e) => {
359 send_auth_permission_denied(&mut socket, &name, Some(&format!("{}", e))).await;
360 send_auth_token_request(&mut socket, &name).await;
362 handled_by_auth = true;
363 }
364 }
365 }
366 }
367 } else {
368 send_auth_permission_denied(&mut socket, &name, Some("invalid-auth-message")).await;
370 send_auth_token_request(&mut socket, &name).await;
372 handled_by_auth = true;
373 }
374 } else {
375 send_auth_token_request(&mut socket, &name).await;
377 handled_by_auth = true; }
379 }
380
381 #[cfg(feature = "redis")]
383 if let Some(bc) = state.redis.as_ref() {
384 let (handle, mut rx) = bc.subscribe(name.clone()).await.expect("redis subscribe");
385 redis_sub_handle = Some(handle);
386 let tx_clone = cmd_tx.clone();
387 tokio::spawn(async move {
388 while let Some((_is_sync, body)) = rx.recv().await {
389 let _ = tx_clone.send(WorkerCmd::InboundWs(body.to_vec()));
390 }
391 });
392 }
393 } else {
394 tracing::debug!("failed to read document name from first frame");
395 }
396 } else if state.auth.is_some() && !is_authenticated {
397 let mut cur = YCursor::new(b.as_ref());
399 let _ = cur.read_string(); let mtype: u32 = cur.read_var().unwrap_or(u32::MAX);
401 if mtype == MSG_AUTH {
402 let name = selected_doc_name.as_ref().unwrap().clone();
403 let sub: u32 = cur.read_var().unwrap_or(u32::MAX);
404 if sub == 0 {
405 if let Ok(token_str) = cur.read_string() {
406 if let Some(provider) = state.auth.as_ref() {
407 match provider.on_authenticate(&name, &token_str) {
408 Ok(scope) => {
409 let readonly = matches!(scope, AuthScope::ReadOnly);
410 send_auth_authenticated(&mut socket, &name, readonly).await;
411 is_authenticated = true;
412 let _ = cmd_tx.send(WorkerCmd::SetReadonly(readonly));
413 handled_by_auth = true;
414 }
415 Err(e) => {
416 send_auth_permission_denied(&mut socket, &name, Some(&format!("{}", e))).await;
417 send_auth_token_request(&mut socket, &name).await;
419 handled_by_auth = true;
420 }
421 }
422 }
423 }
424 } else {
425 send_auth_permission_denied(&mut socket, &name, Some("invalid-auth-message")).await;
426 send_auth_token_request(&mut socket, &name).await;
428 handled_by_auth = true;
429 }
430 } else {
431 let name = selected_doc_name.as_ref().unwrap().clone();
433 send_auth_token_request(&mut socket, &name).await;
434 handled_by_auth = true;
435 }
436 }
437
438 if !loaded_state {
440 if let Some(name) = selected_doc_name.as_ref() {
441 if is_authenticated {
442 tracing::debug!(document_name = %name, "loading state after auth");
443 if let Err(e) = load_and_send_state(&*state.db, name, &cmd_tx).await {
444 tracing::warn!(error = %e, document_name = %name, "failed to load/apply state");
445 }
446 loaded_state = true;
447 }
448 }
449 }
450
451 if !handled_by_auth {
452 let _ = cmd_tx.send(WorkerCmd::InboundWs(b.to_vec()));
453 }
454 }
455 Some(Ok(Message::Ping(p))) => {
456 tracing::debug!(size = %p.len(), "ping received");
457 let _ = socket.send(Message::Pong(p)).await;
458 }
459 Some(Ok(Message::Close(frame))) => {
460 tracing::debug!(pending = pending, ?frame, "closing connection");
461 if let Some(name) = selected_doc_name.as_ref() {
463 if let Some(remaining) = state.doc_registry.decrement(name) {
464 if remaining == 0 {
465 let to_store = latest_state_bytes
466 .as_ref()
467 .cloned()
468 .or_else(|| state.doc_registry.get_latest_cloned(name));
469 if let Some(bytes) = to_store {
470 tracing::debug!(document_name = %name, "last client left; force storing state");
471 let _ = store_bytes(&*state.db, name, &bytes).await;
472 }
473 state.doc_registry.remove(name);
474 }
475 }
476 }
477 if pending {
478 tracing::debug!(pending = pending, "storing state on close");
479 if let (Some(bytes), Some(name)) = (latest_state_bytes.as_ref(), selected_doc_name.as_ref()) {
480 let _ = store_bytes(&*state.db, name, bytes).await;
481 }
482 }
483 let _ = cmd_tx.send(WorkerCmd::Stop);
484 #[cfg(feature = "redis")]
485 if let Some(h) = redis_sub_handle.take() { h.abort(); }
486 break;
487 }
488 None => {
489 tracing::debug!(pending = pending, "socket closed by peer");
490 if let Some(name) = selected_doc_name.as_ref() {
492 if let Some(remaining) = state.doc_registry.decrement(name) {
493 if remaining == 0 {
494 let to_store = latest_state_bytes
495 .as_ref()
496 .cloned()
497 .or_else(|| state.doc_registry.get_latest_cloned(name));
498 if let Some(bytes) = to_store {
499 tracing::debug!(document_name = %name, "last client left; force storing state");
500 let _ = store_bytes(&*state.db, name, &bytes).await;
501 }
502 state.doc_registry.remove(name);
503 }
504 }
505 }
506 if pending {
507 tracing::debug!(pending = pending, "storing state on close");
508 if let (Some(bytes), Some(name)) = (latest_state_bytes.as_ref(), selected_doc_name.as_ref()) {
509 let _ = store_bytes(&*state.db, name, bytes).await;
510 }
511 }
512 let _ = cmd_tx.send(WorkerCmd::Stop);
513 #[cfg(feature = "redis")]
514 if let Some(h) = redis_sub_handle.take() { h.abort(); }
515 break;
516 }
517 _ => {}
518 }
519 }
520 Some(ev) = ev_rx.recv() => {
522 match ev {
523 WorkerEvent::OutgoingWs(bytes) => {
524 tracing::debug!(len = bytes.len(), "sending binary to client");
525 let _ = socket
526 .send(Message::Binary(axum::body::Bytes::from(bytes.clone())))
527 .await;
528 #[cfg(feature = "redis")]
529 if let (Some(name), Some(bc)) = (selected_doc_name.as_ref(), state.redis.as_ref()) {
530 let mut cur = YCursor::new(&bytes);
532 let _ = cur.read_string();
533 let mtype: u32 = cur.read_var().unwrap_or(u32::MAX);
534 let payload = &bytes[..];
535 if mtype == MSG_SYNC {
537 let _ = bc.publish_sync(name, payload).await;
538 } else if mtype == MSG_AWARENESS {
539 let _ = bc.publish_awareness(name, payload).await;
540 }
541 }
542 }
543 WorkerEvent::StoreState(bytes) => {
544 tracing::debug!(state_len = bytes.len(), "doc update observed; scheduling store");
545 if let Some(name) = selected_doc_name.as_ref() {
546 state.doc_registry.set_latest(name, bytes.clone());
547 }
548 latest_state_bytes = Some(bytes);
549 let now = Instant::now();
550 pending = true;
551 if first_pending_at.is_none() { first_pending_at = Some(now); }
552 let debounced = now + Duration::from_millis(state.debounce_ms);
553 let cap = first_pending_at.unwrap() + Duration::from_millis(state.max_debounce_ms);
554 let target = if debounced > cap { cap } else { debounced };
555 next_deadline = Some(target);
556 }
557 }
558 }
559 _ = async { if let Some(fut) = sleep_fut { fut.await } } , if next_deadline.is_some() => {
561 if pending {
562 if let (Some(bytes), Some(name)) = (latest_state_bytes.as_ref(), selected_doc_name.as_ref()) {
563 tracing::debug!(document_name = %name, "debounce elapsed; storing state");
564 if let Err(e) = store_bytes(&*state.db, name, bytes).await {
565 tracing::warn!(error = %e, document_name = %name, "failed to store state");
566 }
567 }
568 pending = false;
569 }
570 first_pending_at = None;
571 next_deadline = None;
572 }
573 }
574 }
575}
576
577async fn load_and_send_state<E: DatabaseExtension>(
578 db: &E,
579 name: &str,
580 cmd_tx: &std::sync::mpsc::Sender<WorkerCmd>,
581) -> Result<()> {
582 tracing::debug!(document_name = %name, "fetching state");
583 if let Some(bytes) = db
584 .fetch(FetchContext {
585 document_name: name.to_string(),
586 })
587 .await?
588 {
589 tracing::debug!(document_name = %name, bytes = bytes.len(), "applying fetched state (worker)");
590 let _ = cmd_tx.send(WorkerCmd::ApplyState(bytes));
591 } else {
592 tracing::debug!(document_name = %name, "no stored state; starting empty doc");
593 }
594 Ok(())
595}
596
597async fn store_bytes<E: DatabaseExtension>(db: &E, name: &str, bytes: &[u8]) -> Result<()> {
598 let now = SystemTime::now()
599 .duration_since(UNIX_EPOCH)
600 .unwrap()
601 .as_millis() as i64;
602 db.store(StoreContext {
603 document_name: name.to_string(),
604 state: bytes,
605 updated_at_millis: now,
606 })
607 .await?;
608 Ok(())
609}
610
/// Access level granted to an authenticated connection.
#[derive(Debug, Clone, Copy)]
pub enum AuthScope {
    /// May read the document; its updates are acknowledged but not applied.
    ReadOnly,
    /// May read and modify the document.
    ReadWrite,
}
618
619pub trait AuthProvider {
620 fn on_authenticate(&self, document_name: &str, token: &str) -> anyhow::Result<AuthScope>;
621}
622
623pub struct StaticTokenAuth {
624 pub token: String,
625 pub scope: AuthScope,
626}
627
628impl AuthProvider for StaticTokenAuth {
629 fn on_authenticate(&self, _document_name: &str, token: &str) -> anyhow::Result<AuthScope> {
630 if token == self.token {
631 Ok(self.scope)
632 } else {
633 anyhow::bail!("permission-denied");
634 }
635 }
636}
637
638async fn send_auth_token_request(socket: &mut WebSocket, name: &str) {
639 let mut out = Vec::new();
640 out.write_string(name);
641 out.write_var(MSG_AUTH);
642 out.write_var(0u32);
644 let _ = socket
645 .send(Message::Binary(axum::body::Bytes::from(out)))
646 .await;
647}
648
649async fn send_auth_authenticated(socket: &mut WebSocket, name: &str, readonly: bool) {
650 let mut out = Vec::new();
651 out.write_string(name);
652 out.write_var(MSG_AUTH);
653 out.write_var(2u32);
655 out.write_string(if readonly { "readonly" } else { "read-write" });
656 let _ = socket
657 .send(Message::Binary(axum::body::Bytes::from(out)))
658 .await;
659}
660
661async fn send_auth_permission_denied(socket: &mut WebSocket, name: &str, reason: Option<&str>) {
662 let mut out = Vec::new();
663 out.write_string(name);
664 out.write_var(MSG_AUTH);
665 out.write_var(1u32);
667 out.write_string(reason.unwrap_or("permission-denied"));
668 let _ = socket
669 .send(Message::Binary(axum::body::Bytes::from(out)))
670 .await;
671}